From 8eadf69c78345bfb2cabfa196272c8ea0eb4ae77 Mon Sep 17 00:00:00 2001 From: Mikhail Chernyavsky Date: Mon, 5 Aug 2019 16:07:55 +0300 Subject: [PATCH 1/3] Cleanup --- .../rust/ide/annotator/RsErrorAnnotator.kt | 74 ++++++++++++++----- .../rust/ide/presentation/RsPsiRendering.kt | 24 ++++-- .../org/rust/lang/core/resolve/ImplLookup.kt | 1 + .../rust/lang/core/resolve/NameResolution.kt | 14 +++- .../core/resolve/ref/RsPathReferenceImpl.kt | 2 +- .../rust/lang/core/types/infer/Fulfillment.kt | 2 +- .../lang/core/types/infer/TypeInference.kt | 7 +- .../core/types/infer/TypeInferenceWalker.kt | 21 ++++-- .../org/rust/lang/core/types/ty/TyAdt.kt | 3 +- .../org/rust/lang/core/types/ty/TyInfer.kt | 1 + 10 files changed, 110 insertions(+), 39 deletions(-) diff --git a/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt b/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt index 64406133d3b..309847d8455 100644 --- a/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt +++ b/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt @@ -128,9 +128,15 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { val declarationFieldsAmount = declaration.fields.size val bodyFieldsAmount = patTupleStruct.patList.size if (bodyFieldsAmount < declarationFieldsAmount && patTupleStruct.patRest == null) { - RsDiagnostic.MissingFieldsInTuplePattern(patTupleStruct, declaration, declarationFieldsAmount, bodyFieldsAmount).addToHolder(holder) + RsDiagnostic.MissingFieldsInTuplePattern( + patTupleStruct, + declaration, + declarationFieldsAmount, + bodyFieldsAmount + ).addToHolder(holder) } else if (bodyFieldsAmount > declarationFieldsAmount) { - RsDiagnostic.ExtraFieldInTupleStructPattern(patTupleStruct, bodyFieldsAmount, declarationFieldsAmount).addToHolder(holder) + RsDiagnostic.ExtraFieldInTupleStructPattern(patTupleStruct, bodyFieldsAmount, declarationFieldsAmount) + .addToHolder(holder) } } @@ -290,19 +296,25 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { val error = when { element is RsNamedFieldDecl -> { val structName = element.ancestorStrict()?.crateRelativePath?.removePrefix("::") ?: "" - RsDiagnostic.StructFieldAccessError(ref, ref.referenceName, structName, - MakePublicFix.createIfCompatible(element, element.name, withinOneCrate)) + RsDiagnostic.StructFieldAccessError( + ref, ref.referenceName, structName, + MakePublicFix.createIfCompatible(element, element.name, withinOneCrate) + ) } - ref is RsMethodCall -> RsDiagnostic.AccessError(ref.identifier, RsErrorCode.E0624, "Method", - MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate)) + ref is RsMethodCall -> RsDiagnostic.AccessError( + ref.identifier, RsErrorCode.E0624, "Method", + MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate) + ) else -> { val itemType = when (element) { is RsItemElement -> element.itemKindName.capitalize() else -> "Item" } - RsDiagnostic.AccessError(ref, RsErrorCode.E0603, itemType, - MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate)) + RsDiagnostic.AccessError( + ref, RsErrorCode.E0603, itemType, + MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate) + ) } } error.addToHolder(holder) @@ -327,7 +339,8 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { if (type.underscore == null) return val owner = type.owner.parent if ((owner is RsValueParameter && owner.parent.parent is RsFunction) - || (owner is RsRetType && owner.parent is RsFunction) || owner is RsConstant) { + || (owner is RsRetType && owner.parent is RsFunction) || owner is RsConstant + ) { RsDiagnostic.TypePlaceholderForbiddenError(type).addToHolder(holder) } } @@ -348,7 +361,10 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { private fun checkPath(holder: RsAnnotationHolder, path: RsPath) { val qualifier = path.path if ((qualifier == null || isValidSelfSuperPrefix(qualifier)) && !isValidSelfSuperPrefix(path)) { - holder.createErrorAnnotation(path.referenceNameElement, "Invalid path: self and super are allowed only at the beginning") + holder.createErrorAnnotation( + path.referenceNameElement, + "Invalid path: self and super are allowed only at the beginning" + ) return } @@ -539,7 +555,12 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { } // E0120: Drop can be only implemented by structs and enums - private fun checkImplDropForNonAdtError(holder: RsAnnotationHolder, impl: RsImplItem, traitRef: RsTraitRef, trait: RsTraitItem) { + private fun checkImplDropForNonAdtError( + holder: RsAnnotationHolder, + impl: RsImplItem, + traitRef: RsTraitRef, + trait: RsTraitItem + ) { if (trait != trait.knownItems.Drop) return if (impl.typeReference?.type is TyAdt?) return @@ -559,7 +580,12 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { } // E0184: Cannot implement both Copy and Drop - private fun checkImplBothCopyAndDrop(holder: RsAnnotationHolder, self: Ty, element: PsiElement, trait: RsTraitItem) { + private fun checkImplBothCopyAndDrop( + holder: RsAnnotationHolder, + self: Ty, + element: PsiElement, + trait: RsTraitItem + ) { val oppositeTrait = when (trait) { trait.knownItems.Drop -> trait.knownItems.Copy trait.knownItems.Copy -> trait.knownItems.Drop @@ -685,7 +711,11 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { RsDiagnostic.InvalidStartAttrError.InvalidParam(params[0].typeReference ?: params[0], 0) .addToHolder(holder) } - if (params[1].typeReference?.type != TyPointer(TyPointer(TyInteger.U8, Mutability.IMMUTABLE), Mutability.IMMUTABLE)) { + if (params[1].typeReference?.type != TyPointer( + TyPointer(TyInteger.U8, Mutability.IMMUTABLE), + Mutability.IMMUTABLE + ) + ) { RsDiagnostic.InvalidStartAttrError.InvalidParam(params[1].typeReference ?: params[1], 1) .addToHolder(holder) } @@ -771,8 +801,10 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { val dotdoteq = range.dotdoteq ?: range.dotdotdot ?: return if (dotdoteq == range.dotdotdot) { // rustc doesn't have an error code for this ("error: unexpected token: `...`") - holder.createErrorAnnotation(dotdoteq, - "`...` syntax is deprecated. Use `..` for an exclusive range or `..=` for an inclusive range") + holder.createErrorAnnotation( + dotdoteq, + "`...` syntax is deprecated. Use `..` for an exclusive range or `..=` for an inclusive range" + ) return } val expr = range.exprList.singleOrNull() ?: return @@ -844,7 +876,12 @@ private fun RsExpr?.isComparisonBinaryExpr(): Boolean { return op is ComparisonOp || op is EqualityOp } -private fun checkDuplicates(holder: RsAnnotationHolder, element: RsNameIdentifierOwner, scope: PsiElement = element.parent, recursively: Boolean = false) { +private fun checkDuplicates( + holder: RsAnnotationHolder, + element: RsNameIdentifierOwner, + scope: PsiElement = element.parent, + recursively: Boolean = false +) { if (element.isCfgUnknown) return val owner = if (scope is RsMembers) scope.parent else scope val duplicates = holder.currentAnnotationSession.duplicatesByNamespace(scope, recursively) @@ -897,7 +934,10 @@ private fun PsiElement.nameOrImportedName(): String? = else -> null } -private fun AnnotationSession.duplicatesByNamespace(owner: PsiElement, recursively: Boolean): Map> { +private fun AnnotationSession.duplicatesByNamespace( + owner: PsiElement, + recursively: Boolean +): Map> { if (owner.parent is RsFnPointerType) return emptyMap() fun PsiElement.namespaced(): Sequence> = diff --git a/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt b/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt index 2625b538742..f05243de1d4 100644 --- a/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt +++ b/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt @@ -14,18 +14,26 @@ import org.rust.lang.core.types.type import org.rust.stdext.joinToWithBuffer /** Return text of the element without switching to AST (loses non-stubbed parts of PSI) */ -fun RsTypeReference.getStubOnlyText(subst: Substitution = emptySubstitution, renderLifetimes: Boolean = true): String = - renderTypeReference(this, subst, renderLifetimes) +fun RsTypeReference.getStubOnlyText( + subst: Substitution = emptySubstitution, + renderLifetimes: Boolean = true +): String = renderTypeReference(this, subst, renderLifetimes) /** Return text of the element without switching to AST (loses non-stubbed parts of PSI) */ -fun RsValueParameterList.getStubOnlyText(subst: Substitution = emptySubstitution, renderLifetimes: Boolean = true): String = - renderValueParameterList(this, subst, renderLifetimes) +fun RsValueParameterList.getStubOnlyText( + subst: Substitution = emptySubstitution, + renderLifetimes: Boolean = true +): String = renderValueParameterList(this, subst, renderLifetimes) /** Return text of the element without switching to AST (loses non-stubbed parts of PSI) */ fun RsTraitRef.getStubOnlyText(subst: Substitution = emptySubstitution, renderLifetimes: Boolean = true): String = buildString { appendPath(path, subst, renderLifetimes) } -private fun renderValueParameterList(list: RsValueParameterList, subst: Substitution, renderLifetimes: Boolean): String { +private fun renderValueParameterList( + list: RsValueParameterList, + subst: Substitution, + renderLifetimes: Boolean +): String { return buildString { append("(") val selfParameter = list.selfParameter @@ -186,7 +194,11 @@ private fun StringBuilder.appendRetType(retType: RsRetType?, subst: Substitution } } -private fun StringBuilder.appendValueParameterListTypes(list: List, subst: Substitution, renderLifetimes: Boolean) { +private fun StringBuilder.appendValueParameterListTypes( + list: List, + subst: Substitution, + renderLifetimes: Boolean +) { list.joinToWithBuffer(this, separator = ", ", prefix = "(", postfix = ")") { sb -> typeReference?.let { sb.appendTypeReference(it, subst, renderLifetimes) } } diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt b/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt index 812506424c9..1e495b45d1c 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt @@ -958,6 +958,7 @@ private sealed class SelectionCandidate { object TraitObject : SelectionCandidate() /** @see ImplLookup.getHardcodedImpls */ object HardcodedImpl : SelectionCandidate() + object Closure : SelectionCandidate() class Projection(val bound: TraitRef) : SelectionCandidate() } diff --git a/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt b/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt index 841c1fb9608..828a6cde76e 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt @@ -1171,9 +1171,15 @@ private fun processAssociatedItems( processor: (AssocItemScopeEntry) -> Boolean ): Boolean { val nsFilter: (RsAbstractable) -> Boolean = when { - Namespace.Types in ns && Namespace.Values in ns -> {{ true }} - Namespace.Types in ns -> {{ it is RsTypeAlias }} - Namespace.Values in ns -> {{ it !is RsTypeAlias }} + Namespace.Types in ns && Namespace.Values in ns -> { + { true } + } + Namespace.Types in ns -> { + { it is RsTypeAlias } + } + Namespace.Values in ns -> { + { it !is RsTypeAlias } + } else -> return false } @@ -1428,7 +1434,7 @@ private fun makeHygieneFilter(anchor: PsiElement): (RsPatBinding) -> Boolean { val nameIdentifier = binding.nameIdentifier ?: return false val bindingHygienicScope = (nameIdentifier.findMacroCallFromWhichLeafIsExpanded() ?: nameIdentifier).containingFile - .unwrapCodeFragments() + .unwrapCodeFragments() return anchorHygienicScope == bindingHygienicScope } } diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt index 4ea410992db..6177d913fa4 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt @@ -120,7 +120,7 @@ fun resolvePath(path: RsPath, lookup: ImplLookup? = null): List instantiatePathGenerics( +fun instantiatePathGenerics( path: RsPath, resolved: BoundElement ): BoundElement { diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/Fulfillment.kt b/src/main/kotlin/org/rust/lang/core/types/infer/Fulfillment.kt index 3a5aff43428..f5c6808609e 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/Fulfillment.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/Fulfillment.kt @@ -171,7 +171,7 @@ class FulfillmentContext(val ctx: RsInferenceContext, val lookup: ImplLookup) { private fun processPredicate(pendingObligation: PendingPredicateObligation): ProcessPredicateResult { val (obligation, stalledOn) = pendingObligation - if (!stalledOn.isEmpty()) { + if (stalledOn.isNotEmpty()) { val nothingChanged = stalledOn.all { val resolvedTy = ctx.shallowResolve(it) resolvedTy == it diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt index 50d3cede6e0..14bacf02690 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt @@ -372,7 +372,7 @@ class RsInferenceContext( } fun canCombineTypes(ty1: Ty, ty2: Ty): Boolean { - return probe { combineTypesResolved(shallowResolve(ty1), shallowResolve(ty2)).isOk } + return probe { combineTypes(ty1, ty2).isOk } } fun combineTypesIfOk(ty1: Ty, ty2: Ty): Boolean { @@ -569,7 +569,10 @@ class RsInferenceContext( optNormalizeProjectionTypeResolved(resolveTypeVarsIfPossible(projectionTy) as TyProjection, recursionDepth) /** See [optNormalizeProjectionType] */ - private fun optNormalizeProjectionTypeResolved(projectionTy: TyProjection, recursionDepth: Int): TyWithObligations? { + private fun optNormalizeProjectionTypeResolved( + projectionTy: TyProjection, + recursionDepth: Int + ): TyWithObligations? { if (projectionTy.type is TyInfer.TyVar) return null return when (val cacheResult = projectionCache.tryStart(projectionTy)) { diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt index 816983852a0..70cf61fb5cd 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt @@ -196,9 +196,12 @@ class RsTypeInferenceWalker( fun inferTypeCoercableTo(expr: RsExpr, expected: Ty): Ty = expr.inferTypeCoercableTo(expected) - private fun coerce(element: RsElement, inferred: Ty, expected: Ty): Boolean { - return coerceResolved(element, resolveTypeVarsWithObligations(inferred), resolveTypeVarsWithObligations(expected)) - } + private fun coerce(element: RsElement, inferred: Ty, expected: Ty): Boolean = + coerceResolved( + element, + resolveTypeVarsWithObligations(inferred), + resolveTypeVarsWithObligations(expected) + ) private fun coerceResolved(element: RsElement, inferred: Ty, expected: Ty): Boolean { when (val result = tryCoerce(inferred, expected)) { @@ -647,7 +650,11 @@ class RsTypeInferenceWalker( } } - private fun pickSingleMethod(receiver: Ty, variants: List, methodCall: RsMethodCall): MethodResolveVariant? { + private fun pickSingleMethod( + receiver: Ty, + variants: List, + methodCall: RsMethodCall + ): MethodResolveVariant? { val filtered = filterAssocItems(variants, methodCall).singleOrLet { list -> // 3. Pick results matching receiver type TypeInferenceMarks.methodPickDerefOrder.hit() @@ -886,8 +893,10 @@ class RsTypeInferenceWalker( if (op is OverloadableBinaryOperator) { val rhsTypeVar = TyInfer.TyVar() enforceOverloadedBinopTypes(lhsType, rhsTypeVar, op) - val rhsType = resolveTypeVarsWithObligations(expr.right?.inferTypeCoercableTo(rhsTypeVar) - ?: TyUnknown) + val rhsType = resolveTypeVarsWithObligations( + expr.right?.inferTypeCoercableTo(rhsTypeVar) + ?: TyUnknown + ) val lhsAdjustment = Adjustment.BorrowReference(TyReference(lhsType, Mutability.IMMUTABLE)) ctx.addAdjustment(expr.left, lhsAdjustment) diff --git a/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt b/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt index 55e20bfbc5d..4a090dbd62d 100644 --- a/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt +++ b/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt @@ -32,8 +32,7 @@ data class TyAdt private constructor( val aliasedBy: BoundElement? ) : Ty(mergeFlags(typeArguments) or mergeFlags(regionArguments)) { - // This method is rarely called (in comparison with folding), - // so we can implement it in a such inefficient way + // This method is rarely called (in comparison with folding), so we can implement it in a such inefficient way. override val typeParameterValues: Substitution get() { val typeSubst = item.typeParameters.withIndex().associate { (i, param) -> diff --git a/src/main/kotlin/org/rust/lang/core/types/ty/TyInfer.kt b/src/main/kotlin/org/rust/lang/core/types/ty/TyInfer.kt index e58376ace1a..33b86cb4e9b 100644 --- a/src/main/kotlin/org/rust/lang/core/types/ty/TyInfer.kt +++ b/src/main/kotlin/org/rust/lang/core/types/ty/TyInfer.kt @@ -16,6 +16,7 @@ sealed class TyInfer : Ty(HAS_TY_INFER_MASK) { val origin: Ty? = null, override var parent: NodeOrValue = VarValue(null, 0) ) : TyInfer(), Node + class IntVar(override var parent: NodeOrValue = VarValue(null, 0)) : TyInfer(), Node class FloatVar(override var parent: NodeOrValue = VarValue(null, 0)) : TyInfer(), Node } From 1fb5dd46b04a1cee37a7ff0d6b3e660cb54f4861 Mon Sep 17 00:00:00 2001 From: Mikhail Chernyavsky Date: Wed, 25 Dec 2019 18:06:55 +0300 Subject: [PATCH 2/3] TY: Initial type inference for const generics --- .../rust/ide/annotator/RsErrorAnnotator.kt | 11 +- .../hints/RsTypeHintsPresentationFactory.kt | 52 +++-- .../inspections/checkMatch/CheckMatchUtils.kt | 8 +- .../ide/inspections/checkMatch/Constructor.kt | 29 +-- .../ide/inspections/checkMatch/PatternKind.kt | 6 +- .../SpecifyTypeExplicitlyIntention.kt | 7 +- .../rust/ide/presentation/RsPsiRendering.kt | 44 ++-- .../rust/ide/presentation/TypeRendering.kt | 27 ++- .../rust/ide/utils/BooleanExprSimplifier.kt | 7 +- .../lang/core/completion/LookupElements.kt | 4 +- .../org/rust/lang/core/psi/ext/RsArrayType.kt | 10 +- .../org/rust/lang/core/psi/ext/RsBlock.kt | 13 +- .../org/rust/lang/core/psi/ext/RsBlockExpr.kt | 7 +- .../org/rust/lang/core/psi/ext/RsTraitItem.kt | 7 +- .../org/rust/lang/core/resolve/ImplLookup.kt | 59 +++-- .../rust/lang/core/resolve/NameResolution.kt | 22 +- .../lang/core/resolve/RsCachedImplItem.kt | 6 +- .../core/resolve/ref/RsPathReferenceImpl.kt | 22 +- .../lang/core/stubs/StubImplementations.kt | 41 +++- .../org/rust/lang/core/types/BoundElement.kt | 14 +- .../org/rust/lang/core/types/Extensions.kt | 7 +- .../kotlin/org/rust/lang/core/types/Kind.kt | 5 +- .../org/rust/lang/core/types/Substitution.kt | 42 +++- .../org/rust/lang/core/types/consts/Const.kt | 23 ++ .../core/types/consts/CtConstParameter.kt | 20 ++ .../rust/lang/core/types/consts/CtInfer.kt | 18 ++ .../lang/core/types/consts/CtUnevaluated.kt | 21 ++ .../rust/lang/core/types/consts/CtUnknown.kt | 10 + .../rust/lang/core/types/consts/CtValue.kt | 22 ++ .../lang/core/types/infer/Declarations.kt | 20 +- .../org/rust/lang/core/types/infer/Fold.kt | 92 +++++++- .../lang/core/types/infer/TypeInference.kt | 179 ++++++++++++---- .../core/types/infer/TypeInferenceWalker.kt | 202 ++++++++++++------ .../lang/core/types/regions/ReEarlyBound.kt | 14 +- .../org/rust/lang/core/types/ty/TyAdt.kt | 34 ++- .../org/rust/lang/core/types/ty/TyArray.kt | 10 +- .../rust/lang/utils/evaluation/ConstExpr.kt | 92 ++++++++ .../lang/utils/evaluation/ConstExprBuilder.kt | 169 +++++++++++++++ .../utils/evaluation/ConstExprEvaluator.kt | 185 ++++++++++++++++ .../rust/lang/utils/evaluation/ExprValue.kt | 28 --- .../utils/evaluation/RsConstExprEvaluator.kt | 181 ---------------- .../org/rust/lang/utils/evaluation/Utils.kt | 19 ++ 42 files changed, 1304 insertions(+), 485 deletions(-) create mode 100644 src/main/kotlin/org/rust/lang/core/types/consts/Const.kt create mode 100644 src/main/kotlin/org/rust/lang/core/types/consts/CtConstParameter.kt create mode 100644 src/main/kotlin/org/rust/lang/core/types/consts/CtInfer.kt create mode 100644 src/main/kotlin/org/rust/lang/core/types/consts/CtUnevaluated.kt create mode 100644 src/main/kotlin/org/rust/lang/core/types/consts/CtUnknown.kt create mode 100644 src/main/kotlin/org/rust/lang/core/types/consts/CtValue.kt create mode 100644 src/main/kotlin/org/rust/lang/utils/evaluation/ConstExpr.kt create mode 100644 src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprBuilder.kt create mode 100644 src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprEvaluator.kt delete mode 100644 src/main/kotlin/org/rust/lang/utils/evaluation/ExprValue.kt delete mode 100644 src/main/kotlin/org/rust/lang/utils/evaluation/RsConstExprEvaluator.kt create mode 100644 src/main/kotlin/org/rust/lang/utils/evaluation/Utils.kt diff --git a/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt b/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt index 309847d8455..a0ae9ead830 100644 --- a/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt +++ b/src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt @@ -31,12 +31,12 @@ import org.rust.lang.core.resolve.knownItems import org.rust.lang.core.resolve.namespaces import org.rust.lang.core.resolve.ref.deepResolve import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.asLong import org.rust.lang.core.types.ty.* import org.rust.lang.utils.RsDiagnostic import org.rust.lang.utils.RsErrorCode import org.rust.lang.utils.addToHolder -import org.rust.lang.utils.evaluation.ExprValue -import org.rust.lang.utils.evaluation.RsConstExprEvaluator +import org.rust.lang.utils.evaluation.evaluate class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { override fun isForceHighlightParents(file: PsiFile): Boolean = file is RsFile @@ -211,12 +211,7 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension { val indexToVariantMap = hashMapOf() for (variant in o.enumVariantList) { val expr = variant.variantDiscriminant?.expr - val int = if (expr != null) { - val result = RsConstExprEvaluator.evaluate(expr, reprType) as? ExprValue.Integer - result?.value ?: return - } else { - null - } + val int = if (expr != null) expr.evaluate(reprType).asLong() ?: return else null val idx = int ?: discrCounter discrCounter = idx + 1 diff --git a/src/main/kotlin/org/rust/ide/hints/RsTypeHintsPresentationFactory.kt b/src/main/kotlin/org/rust/ide/hints/RsTypeHintsPresentationFactory.kt index 69225c37cbc..a495fc00b02 100644 --- a/src/main/kotlin/org/rust/ide/hints/RsTypeHintsPresentationFactory.kt +++ b/src/main/kotlin/org/rust/ide/hints/RsTypeHintsPresentationFactory.kt @@ -13,6 +13,9 @@ import org.rust.lang.core.psi.RsTypeAlias import org.rust.lang.core.psi.RsTypeParameter import org.rust.lang.core.psi.ext.typeParameters import org.rust.lang.core.types.BoundElement +import org.rust.lang.core.types.Kind +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtValue import org.rust.lang.core.types.ty.* import org.rust.lang.core.types.type @@ -22,19 +25,22 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p listOf(text(": "), hint(type, 1)).join() ) - private fun hint(type: Ty, level: Int): InlayPresentation = when (type) { - is TyTuple -> tupleTypeHint(type, level) - is TyAdt -> adtTypeHint(type, level) - is TyFunction -> functionTypeHint(type, level) - is TyReference -> referenceTypeHint(type, level) - is TyPointer -> pointerTypeHint(type, level) - is TyProjection -> projectionTypeHint(type, level) - is TyTypeParameter -> typeParameterTypeHint(type) - is TyArray -> arrayTypeHint(type, level) - is TySlice -> sliceTypeHint(type, level) - is TyTraitObject -> traitObjectTypeHint(type, level) - is TyAnon -> anonTypeHint(type, level) - else -> text(type.shortPresentableText) + private fun hint(kind: Kind, level: Int): InlayPresentation = when (kind) { + is TyTuple -> tupleTypeHint(kind, level) + is TyAdt -> adtTypeHint(kind, level) + is TyFunction -> functionTypeHint(kind, level) + is TyReference -> referenceTypeHint(kind, level) + is TyPointer -> pointerTypeHint(kind, level) + is TyProjection -> projectionTypeHint(kind, level) + is TyTypeParameter -> typeParameterTypeHint(kind) + is TyArray -> arrayTypeHint(kind, level) + is TySlice -> sliceTypeHint(kind, level) + is TyTraitObject -> traitObjectTypeHint(kind, level) + is TyAnon -> anonTypeHint(kind, level) + is CtConstParameter -> constParameterTypeHint(kind) + is CtValue -> text(kind.expr.toString()) + is Ty -> text(kind.shortPresentableText) + else -> text(null) } private fun functionTypeHint(type: TyFunction, level: Int): InlayPresentation { @@ -88,22 +94,25 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p val typeArguments = alias?.typeParameters?.map { (aliasedBy.subst[it] ?: TyUnknown) to it } ?: type.typeArguments.zip(type.item.typeParameters) - val userVisibleTypeArguments = mutableListOf() + val userVisibleKindArguments = mutableListOf() for ((argument, parameter) in typeArguments) { if (!showObviousTypes && isDefaultTypeParameter(argument, parameter)) { // don't show default types continue } - userVisibleTypeArguments.add(argument) + userVisibleKindArguments.add(argument) + } + for (argument in type.constArguments) { + userVisibleKindArguments.add(argument) } - if (userVisibleTypeArguments.isNotEmpty()) { + if (userVisibleKindArguments.isNotEmpty()) { val collapsible = factory.collapsible( prefix = text("<"), collapsed = text(PLACEHOLDER), - expanded = { parametersHint(userVisibleTypeArguments, level + 1) }, + expanded = { parametersHint(userVisibleKindArguments, level + 1) }, suffix = text(">"), - startWithPlaceholder = checkSize(level, userVisibleTypeArguments.size) + startWithPlaceholder = checkSize(level, userVisibleKindArguments.size) ) return listOf(typeNamePresentation, collapsible).join() } @@ -147,6 +156,9 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p return text(parameter.name) } + private fun constParameterTypeHint(const: CtConstParameter): InlayPresentation = + factory.psiSingleReference(text(const.parameter.name)) { const.parameter } + private fun arrayTypeHint(type: TyArray, level: Int): InlayPresentation = factory.collapsible( prefix = text("["), @@ -187,8 +199,8 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p startWithPlaceholder = checkSize(level, type.traits.size) ) - private fun parametersHint(types: List, level: Int): InlayPresentation = - types.map { hint(it, level) }.join(", ") + private fun parametersHint(kinds: List, level: Int): InlayPresentation = + kinds.map { hint(it, level) }.join(", ") private fun traitItemTypeHint( trait: BoundElement, diff --git a/src/main/kotlin/org/rust/ide/inspections/checkMatch/CheckMatchUtils.kt b/src/main/kotlin/org/rust/ide/inspections/checkMatch/CheckMatchUtils.kt index 17bce5040a3..b2a43c0c6e3 100644 --- a/src/main/kotlin/org/rust/ide/inspections/checkMatch/CheckMatchUtils.kt +++ b/src/main/kotlin/org/rust/ide/inspections/checkMatch/CheckMatchUtils.kt @@ -7,12 +7,13 @@ package org.rust.ide.inspections.checkMatch import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* +import org.rust.lang.core.types.consts.CtValue import org.rust.lang.core.types.ty.Ty import org.rust.lang.core.types.ty.TyAdt import org.rust.lang.core.types.ty.TyUnknown import org.rust.lang.core.types.type -import org.rust.lang.utils.evaluation.ExprValue -import org.rust.lang.utils.evaluation.RsConstExprEvaluator +import org.rust.lang.utils.evaluation.ConstExpr.Value +import org.rust.lang.utils.evaluation.evaluate class CheckMatchException(message: String) : Exception(message) @@ -44,7 +45,8 @@ val Matrix.firstColumnType: Ty fun List.calculateMatrix(): Matrix = flatMap { arm -> arm.patList.map { listOf(it.lower) } } -private val RsExpr.value: ExprValue? get() = RsConstExprEvaluator.evaluate(this) +private val RsExpr.value: Value<*>? + get() = (evaluate() as? CtValue)?.expr // lower_pattern_unadjusted private val RsPat.kind: PatternKind diff --git a/src/main/kotlin/org/rust/ide/inspections/checkMatch/Constructor.kt b/src/main/kotlin/org/rust/ide/inspections/checkMatch/Constructor.kt index 32a66c3f3fd..374ed5b8956 100644 --- a/src/main/kotlin/org/rust/ide/inspections/checkMatch/Constructor.kt +++ b/src/main/kotlin/org/rust/ide/inspections/checkMatch/Constructor.kt @@ -13,21 +13,21 @@ import org.rust.lang.core.psi.ext.fieldTypes import org.rust.lang.core.psi.ext.size import org.rust.lang.core.psi.ext.variants import org.rust.lang.core.types.ty.* -import org.rust.lang.utils.evaluation.ExprValue +import org.rust.lang.utils.evaluation.ConstExpr.Value sealed class Constructor { /** The constructor of all patterns that don't vary by constructor, e.g. struct patterns and fixed-length arrays */ object Single : Constructor() { - override fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean = true + override fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean = true } /** Enum variants */ data class Variant(val variant: RsEnumVariant) : Constructor() /** Literal values */ - data class ConstantValue(val value: ExprValue) : Constructor() { - override fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean = + data class ConstantValue(val value: Value<*>) : Constructor() { + override fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean = if (included) { value >= from && value <= to } else { @@ -36,8 +36,9 @@ sealed class Constructor { } /** Ranges of literal values (`2..=5` and `2..5`) */ - data class ConstantRange(val start: ExprValue, val end: ExprValue, val includeEnd: Boolean = false) : Constructor() { - override fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean = + data class ConstantRange(val start: Value<*>, val end: Value<*>, val includeEnd: Boolean = false) : + Constructor() { + override fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean = if (includeEnd) { ((end < to) || (included && to == end)) && (start >= from) } else { @@ -68,7 +69,7 @@ sealed class Constructor { else -> 0 } - open fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean = false + open fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean = false fun subTypes(type: Ty): List = when (type) { is TyTuple -> type.types @@ -94,7 +95,7 @@ sealed class Constructor { companion object { fun allConstructors(ty: Ty): List = when { - ty is TyBool -> listOf(true, false).map { ConstantValue(ExprValue.Bool(it)) } + ty is TyBool -> listOf(true, false).map { ConstantValue(Value.Bool(it)) } ty is TyAdt && ty.item is RsEnumItem -> ty.item.variants.map { Variant(it) } @@ -107,13 +108,13 @@ sealed class Constructor { } } -private operator fun ExprValue.compareTo(other: ExprValue): Int { +private operator fun Value<*>.compareTo(other: Value<*>): Int { return when { - this is ExprValue.Bool && other is ExprValue.Bool -> value.compareTo(other.value) - this is ExprValue.Integer && other is ExprValue.Integer -> value.compareTo(other.value) - this is ExprValue.Float && other is ExprValue.Float -> value.compareTo(other.value) - this is ExprValue.Str && other is ExprValue.Str -> value.compareTo(other.value) - this is ExprValue.Char && other is ExprValue.Char -> value.compareTo(other.value) + this is Value.Bool && other is Value.Bool -> value.compareTo(other.value) + this is Value.Integer && other is Value.Integer -> value.compareTo(other.value) + this is Value.Float && other is Value.Float -> value.compareTo(other.value) + this is Value.Str && other is Value.Str -> value.compareTo(other.value) + this is Value.Char && other is Value.Char -> value.compareTo(other.value) else -> throw CheckMatchException("Comparison of incompatible types: $javaClass and ${other.javaClass}") } } diff --git a/src/main/kotlin/org/rust/ide/inspections/checkMatch/PatternKind.kt b/src/main/kotlin/org/rust/ide/inspections/checkMatch/PatternKind.kt index 881864d2530..c24a198535f 100644 --- a/src/main/kotlin/org/rust/ide/inspections/checkMatch/PatternKind.kt +++ b/src/main/kotlin/org/rust/ide/inspections/checkMatch/PatternKind.kt @@ -8,7 +8,7 @@ package org.rust.ide.inspections.checkMatch import org.rust.lang.core.psi.RsEnumItem import org.rust.lang.core.psi.RsEnumVariant import org.rust.lang.core.types.ty.Ty -import org.rust.lang.utils.evaluation.ExprValue +import org.rust.lang.utils.evaluation.ConstExpr.Value sealed class PatternKind { object Wild : PatternKind() @@ -25,9 +25,9 @@ sealed class PatternKind { /** &P, &mut P, etc */ data class Deref(val subPattern: Pattern) : PatternKind() - data class Const(val value: ExprValue) : PatternKind() + data class Const(val value: Value<*>) : PatternKind() - data class Range(val lc: ExprValue, val rc: ExprValue, val isInclusive: Boolean) : PatternKind() + data class Range(val lc: Value<*>, val rc: Value<*>, val isInclusive: Boolean) : PatternKind() interface SliceField { diff --git a/src/main/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntention.kt b/src/main/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntention.kt index 3d7c54b2297..71427d9e438 100644 --- a/src/main/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntention.kt +++ b/src/main/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntention.kt @@ -14,6 +14,10 @@ import org.rust.lang.core.psi.RsLetDecl import org.rust.lang.core.psi.RsPsiFactory import org.rust.lang.core.psi.ext.ancestorStrict import org.rust.lang.core.psi.ext.startOffset +import org.rust.lang.core.types.consts.CtInferVar +import org.rust.lang.core.types.consts.CtUnevaluated +import org.rust.lang.core.types.consts.CtUnknown +import org.rust.lang.core.types.infer.containsConstOfClass import org.rust.lang.core.types.infer.containsTyOfClass import org.rust.lang.core.types.ty.Ty import org.rust.lang.core.types.ty.TyAnon @@ -34,7 +38,8 @@ class SpecifyTypeExplicitlyIntention : RsElementBaseIntentionAction= initializer.startOffset - 1) return null val pat = letDecl.pat ?: return null val type = pat.type - if (type.containsTyOfClass(listOf(TyUnknown::class.java, TyInfer::class.java, TyAnon::class.java))) { + if (type.containsTyOfClass(TyUnknown::class.java, TyInfer::class.java, TyAnon::class.java) + || type.containsConstOfClass(CtUnknown::class.java, CtInferVar::class.java, CtUnevaluated::class.java)) { return null } return Context(type, letDecl) diff --git a/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt b/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt index f05243de1d4..44be5656ade 100644 --- a/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt +++ b/src/main/kotlin/org/rust/ide/presentation/RsPsiRendering.kt @@ -8,9 +8,13 @@ package org.rust.ide.presentation import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* import org.rust.lang.core.types.Substitution +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtValue import org.rust.lang.core.types.emptySubstitution +import org.rust.lang.core.types.infer.substitute import org.rust.lang.core.types.ty.TyTypeParameter import org.rust.lang.core.types.type +import org.rust.lang.utils.evaluation.evaluate import org.rust.stdext.joinToWithBuffer /** Return text of the element without switching to AST (loses non-stubbed parts of PSI) */ @@ -111,7 +115,7 @@ private fun StringBuilder.appendTypeReference(ref: RsTypeReference, subst: Subst type.typeReference?.let { appendTypeReference(it, subst, renderLifetimes) } if (!type.isSlice) { append("; ") - append(type.arraySize) // may trigger resolve + append(type.arraySize ?: "{}") // may trigger resolve } append("]") } @@ -151,29 +155,37 @@ private fun StringBuilder.appendPath(path: RsPath, subst: Substitution, renderLi val inAngles = path.typeArgumentList // Foo<...> val fnSugar = path.valueParameterList // &dyn FnOnce(...) -> i32 if (inAngles != null) { - val lifetimeList = inAngles.lifetimeList - val typeReferenceList = inAngles.typeReferenceList - val assocTypeBindingList = inAngles.assocTypeBindingList + val lifetimeArguments = inAngles.lifetimeList + val typeArguments = inAngles.typeReferenceList + val constArguments = inAngles.exprList + val assocTypeBindings = inAngles.assocTypeBindingList - val hasLifetimes = renderLifetimes && lifetimeList.isNotEmpty() - val hasTypeReferences = typeReferenceList.isNotEmpty() - val hasAssocTypeBindings = assocTypeBindingList.isNotEmpty() + val hasLifetimes = renderLifetimes && lifetimeArguments.isNotEmpty() + val hasTypeReferences = typeArguments.isNotEmpty() + val hasConstArguments = constArguments.isNotEmpty() + val hasAssocTypeBindings = assocTypeBindings.isNotEmpty() - if (hasLifetimes || hasTypeReferences || hasAssocTypeBindings) { + if (hasLifetimes || hasTypeReferences || hasConstArguments || hasAssocTypeBindings) { append("<") if (hasLifetimes) { - lifetimeList.joinToWithBuffer(this, ", ") { it.append(referenceName) } - if (hasTypeReferences || hasAssocTypeBindings) { + lifetimeArguments.joinToWithBuffer(this, ", ") { it.append(referenceName) } + if (hasTypeReferences || hasConstArguments || hasAssocTypeBindings) { append(", ") } } if (hasTypeReferences) { - typeReferenceList.joinToWithBuffer(this, ", ") { it.appendTypeReference(this, subst, renderLifetimes) } + typeArguments.joinToWithBuffer(this, ", ") { it.appendTypeReference(this, subst, renderLifetimes) } + if (hasConstArguments || hasAssocTypeBindings) { + append(", ") + } + } + if (hasConstArguments) { + constArguments.joinToWithBuffer(this, ", ") { it.appendConstExpr(this, subst) } if (hasAssocTypeBindings) { append(", ") } } - assocTypeBindingList.joinToWithBuffer(this, ", ") { sb -> + assocTypeBindings.joinToWithBuffer(this, ", ") { sb -> sb.append(referenceName) sb.append("=") typeReference?.let { sb.appendTypeReference(it, subst, renderLifetimes) } @@ -186,6 +198,14 @@ private fun StringBuilder.appendPath(path: RsPath, subst: Substitution, renderLi } } +private fun StringBuilder.appendConstExpr(expr: RsExpr, subst: Substitution) { + when (val const = expr.evaluate().substitute(subst)) { // may trigger resolve + is CtValue -> append(const) + is CtConstParameter -> append("{ $const }") + else -> append("{}") + } +} + private fun StringBuilder.appendRetType(retType: RsRetType?, subst: Substitution, renderLifetimes: Boolean) { val retTypeRef = retType?.typeReference if (retTypeRef != null) { diff --git a/src/main/kotlin/org/rust/ide/presentation/TypeRendering.kt b/src/main/kotlin/org/rust/ide/presentation/TypeRendering.kt index b29364a400d..7266bf3ff13 100644 --- a/src/main/kotlin/org/rust/ide/presentation/TypeRendering.kt +++ b/src/main/kotlin/org/rust/ide/presentation/TypeRendering.kt @@ -9,6 +9,10 @@ import org.jetbrains.annotations.TestOnly import org.rust.lang.core.psi.RsTraitItem import org.rust.lang.core.psi.ext.* import org.rust.lang.core.types.BoundElement +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtUnknown +import org.rust.lang.core.types.consts.CtValue import org.rust.lang.core.types.regions.ReEarlyBound import org.rust.lang.core.types.regions.ReStatic import org.rust.lang.core.types.regions.ReUnknown @@ -49,6 +53,7 @@ private data class TypeRenderer( val unknown: String = "", val anonymous: String = "", val unknownLifetime: String = "'", + val unknownConst: String = "", val integer: String = "{integer}", val float: String = "{float}", val includeTypeArguments: Boolean = true, @@ -85,7 +90,7 @@ private data class TypeRenderer( is TySlice -> "[${render(ty.elementType)}]" is TyTuple -> ty.types.joinToString(", ", "(", ")", transform = render) - is TyArray -> "[${render(ty.base)}; ${ty.size ?: unknown}]" + is TyArray -> "[${render(ty.base)}; ${render(ty.const)}]" is TyReference -> buildString { append('&') if (includeLifetimeArguments && (ty.region is ReEarlyBound || ty.region is ReStatic)) { @@ -139,6 +144,13 @@ private data class TypeRenderer( private fun render(region: Region): String = if (region == ReUnknown) unknownLifetime else region.toString() + private fun render(const: Const, wrapParameterInBraces: Boolean = false): String = + when (const) { + is CtValue -> const.toString() + is CtConstParameter -> if (wrapParameterInBraces) "{ $const }" else const.toString() + else -> unknownConst + } + private fun formatFnLike(fnType: String, paramTypes: List, retType: Ty, render: (Ty) -> String): String = buildString { paramTypes.joinTo(this, ", ", "$fnType(", ")", transform = render) @@ -169,12 +181,9 @@ private data class TypeRenderer( private fun formatGenerics(adt: TyAdt, render: (Ty) -> String): String { val typeArgumentNames = adt.typeArguments.map(render) - val regionArgumentNames = if (includeLifetimeArguments) { - adt.regionArguments.map { render(it) } - } else { - emptyList() - } - val generics = regionArgumentNames + typeArgumentNames + val regionArgumentNames = if (includeLifetimeArguments) adt.regionArguments.map { render(it) } else emptyList() + val constArgumentNames = adt.constArguments.map { render(it, wrapParameterInBraces = true) } + val generics = regionArgumentNames + typeArgumentNames + constArgumentNames return if (generics.isEmpty()) "" else generics.joinToString(", ", "<", ">") } @@ -218,7 +227,8 @@ private data class TypeRenderer( } else { emptyList() } - return regionSubst + tySubst + val constSubst = boundElement.element.constParameters.map { render(boundElement.subst[it] ?: CtUnknown) } + return regionSubst + tySubst + constSubst } companion object { @@ -229,6 +239,7 @@ private data class TypeRenderer( unknown = "_", anonymous = "_", unknownLifetime = "'_", + unknownConst = "{}", integer = "_", float = "_" ) diff --git a/src/main/kotlin/org/rust/ide/utils/BooleanExprSimplifier.kt b/src/main/kotlin/org/rust/ide/utils/BooleanExprSimplifier.kt index 94bbbfd3fe9..4d9168d14cd 100644 --- a/src/main/kotlin/org/rust/ide/utils/BooleanExprSimplifier.kt +++ b/src/main/kotlin/org/rust/ide/utils/BooleanExprSimplifier.kt @@ -8,9 +8,9 @@ package org.rust.ide.utils import com.intellij.openapi.project.Project import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* +import org.rust.lang.core.types.consts.asBool import org.rust.lang.core.types.ty.TyBool -import org.rust.lang.utils.evaluation.ExprValue -import org.rust.lang.utils.evaluation.RsConstExprEvaluator +import org.rust.lang.utils.evaluation.evaluate import org.rust.lang.utils.negate class BooleanExprSimplifier(val project: Project) { @@ -109,7 +109,6 @@ class BooleanExprSimplifier(val project: Project) { private fun canBeEvaluated(expr: RsExpr): Boolean = eval(expr) != null - private fun eval(expr: RsExpr): Boolean? = - (RsConstExprEvaluator.evaluate(expr, TyBool, null) as? ExprValue.Bool)?.value + private fun eval(expr: RsExpr): Boolean? = expr.evaluate(TyBool, resolver = null).asBool() } } diff --git a/src/main/kotlin/org/rust/lang/core/completion/LookupElements.kt b/src/main/kotlin/org/rust/lang/core/completion/LookupElements.kt index 1ebf23a8eb3..4e3c2be844b 100644 --- a/src/main/kotlin/org/rust/lang/core/completion/LookupElements.kt +++ b/src/main/kotlin/org/rust/lang/core/completion/LookupElements.kt @@ -123,7 +123,9 @@ fun createLookupElement( private fun RsInferenceContext.getSubstitution(scopeEntry: ScopeEntry): Substitution = when (scopeEntry) { is AssocItemScopeEntryBase<*> -> - instantiateMethodOwnerSubstitution(scopeEntry).mapTypeValues { (_, v) -> resolveTypeVarsIfPossible(v) } + instantiateMethodOwnerSubstitution(scopeEntry) + .mapTypeValues { (_, v) -> resolveTypeVarsIfPossible(v) } + .mapConstValues { (_, v) -> resolveTypeVarsIfPossible(v) } is FieldResolveVariant -> scopeEntry.selfTy.typeParameterValues else -> diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsArrayType.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsArrayType.kt index 0fb28a7b2db..98036ae29ed 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsArrayType.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsArrayType.kt @@ -6,14 +6,10 @@ package org.rust.lang.core.psi.ext import org.rust.lang.core.psi.RsArrayType +import org.rust.lang.core.types.consts.asLong import org.rust.lang.core.types.ty.TyInteger -import org.rust.lang.utils.evaluation.ExprValue -import org.rust.lang.utils.evaluation.RsConstExprEvaluator +import org.rust.lang.utils.evaluation.evaluate val RsArrayType.isSlice: Boolean get() = greenStub?.isSlice ?: (expr == null) -val RsArrayType.arraySize: Long? - get() { - val expr = expr ?: return null - return (RsConstExprEvaluator.evaluate(expr, TyInteger.USize) as? ExprValue.Integer)?.value - } +val RsArrayType.arraySize: Long? get() = expr?.evaluate(TyInteger.USize)?.asLong() diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt index 5aa8be34d4e..15c737f9f35 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt @@ -37,9 +37,18 @@ val RsBlock.expandedStmtsAndTailExpr: Pair, RsExpr?> private val RsBlock.stmtsAndMacros: Sequence get() { - val parentItem = contextStrict() val stub = greenStub - return if (stub != null && parentItem is RsConstant && parentItem.isConst) { + + fun isConstant(): Boolean { + val parentItem = contextStrict() + return parentItem is RsConstant && parentItem.isConst + } + + fun isConstExpr(): Boolean { + return contextStrict() != null + } + + return if (stub != null && (isConstant() || isConstExpr())) { stub.childrenStubs.asSequence().map { it.psi } } else { childrenWithLeaves diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlockExpr.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlockExpr.kt index a15fd8c221f..a57e6fd267f 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlockExpr.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlockExpr.kt @@ -7,12 +7,13 @@ package org.rust.lang.core.psi.ext import org.rust.lang.core.psi.RsBlockExpr import org.rust.lang.core.psi.RsElementTypes +import org.rust.lang.core.stubs.RsBlockExprStub val RsBlockExpr.isUnsafe: Boolean - get() = node.findChildByType(RsElementTypes.UNSAFE) != null + get() = (greenStub as? RsBlockExprStub)?.isUnsafe ?: (node.findChildByType(RsElementTypes.UNSAFE) != null) val RsBlockExpr.isAsync: Boolean - get() = node.findChildByType(RsElementTypes.ASYNC) != null + get() = (greenStub as? RsBlockExprStub)?.isAsync ?: (node.findChildByType(RsElementTypes.ASYNC) != null) val RsBlockExpr.isTry: Boolean - get() = node.findChildByType(RsElementTypes.TRY) != null + get() = (greenStub as? RsBlockExprStub)?.isTry ?: (node.findChildByType(RsElementTypes.TRY) != null) diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt index a2061714f7b..c3a1eed6257 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt @@ -20,6 +20,7 @@ import org.rust.lang.core.resolve.KNOWN_DERIVABLE_TRAITS import org.rust.lang.core.resolve.knownItems import org.rust.lang.core.stubs.RsTraitItemStub import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.CtConstParameter import org.rust.lang.core.types.infer.substitute import org.rust.lang.core.types.regions.ReEarlyBound import org.rust.lang.core.types.ty.Ty @@ -120,7 +121,11 @@ private fun defaultSubstitution(item: RsTraitItem): Substitution { val parameter = ReEarlyBound(it) parameter to parameter } - return Substitution(typeSubst, regionSubst) + val constSubst = item.constParameters.associate { + val parameter = CtConstParameter(it) + parameter to parameter + } + return Substitution(typeSubst, regionSubst, constSubst) } abstract class RsTraitItemImplMixin : RsStubbedNamedElementImpl, RsTraitItem { diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt b/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt index 1e495b45d1c..842c31c0294 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt @@ -13,6 +13,9 @@ import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* import org.rust.lang.core.resolve.indexes.RsImplIndex import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtInferVar +import org.rust.lang.core.types.consts.FreshCtInferVar import org.rust.lang.core.types.infer.* import org.rust.lang.core.types.ty.* import org.rust.lang.core.types.ty.Mutability.IMMUTABLE @@ -443,8 +446,11 @@ class ImplLookup( private fun findExplicitImpls(selfTy: Ty, processor: RsProcessor): Boolean { return RsImplIndex.findPotentialImpls(project, selfTy) { cachedImpl -> - val (type, generics) = cachedImpl.typeAndGenerics ?: return@findPotentialImpls false - val subst = generics.associateWith { ctx.typeVarForParam(it) }.toTypeSubst() + val (type, generics, constGenerics) = cachedImpl.typeAndGenerics ?: return@findPotentialImpls false + val subst = Substitution( + typeSubst = generics.associateWith { ctx.typeVarForParam(it) }, + constSubst = constGenerics.associateWith { ctx.constVarForParam(it) } + ) // TODO: take into account the lifetimes (?) val formalSelfTy = type.substitute(subst) val isAppropriateImpl = ctx.canCombineTypes(formalSelfTy, selfTy) && @@ -538,18 +544,23 @@ class ImplLookup( * ` S == S` */ private fun > freshen(ty: T): T { + val tyMap = hashMapOf() + val constMap = hashMapOf() + var counter = 0 - val map = HashMap() - - return ty.foldTyInferWith { - map.getOrPut(it) { - when (it) { - is TyInfer.TyVar -> FreshTyInfer.TyVar(counter++) - is TyInfer.IntVar -> FreshTyInfer.IntVar(counter++) - is TyInfer.FloatVar -> FreshTyInfer.FloatVar(counter++) + return ty + .foldTyInferWith { + tyMap.getOrPut(it) { + when (it) { + is TyInfer.TyVar -> FreshTyInfer.TyVar(counter++) + is TyInfer.IntVar -> FreshTyInfer.IntVar(counter++) + is TyInfer.FloatVar -> FreshTyInfer.FloatVar(counter++) + } } } - } + .foldCtInferWith { + constMap.getOrPut(it) { FreshCtInferVar(counter++) } + } } private fun canEvaluateObligations(ref: TraitRef, candidate: SelectionCandidate, recursionDepth: Int): Boolean { @@ -599,7 +610,10 @@ class ImplLookup( ?.let { add(SelectionCandidate.TraitObject) } } getHardcodedImpls(ref.selfTy).filter { be -> - be.element == element && ctx.probe { ctx.combinePairs(be.subst.zipTypeValues(ref.trait.subst)).isOk } + be.element == element && ctx.probe { + ctx.combineTypePairs(be.subst.zipTypeValues(ref.trait.subst)).isOk && + ctx.combineConstPairs(be.subst.zipConstValues(ref.trait.subst)).isOk + } }.forEach { add(SelectionCandidate.HardcodedImpl) } } } @@ -615,9 +629,9 @@ class ImplLookup( private fun RsCachedImplItem.trySelectCandidate(ref: TraitRef): SelectionCandidate? { val formalTraitRef = implementedTrait ?: return null if (formalTraitRef.element != ref.trait.element) return null - val (formalSelfTy, generics) = typeAndGenerics ?: return null + val (formalSelfTy, generics, constGenerics) = typeAndGenerics ?: return null val (_, implTraitRef) = - prepareSubstAndTraitRefRaw(ctx, generics, formalSelfTy, formalTraitRef, ref.selfTy) + prepareSubstAndTraitRefRaw(ctx, generics, constGenerics, formalSelfTy, formalTraitRef, ref.selfTy) if (!ctx.probe { ctx.combineTraitRefs(implTraitRef, ref) }) return null return SelectionCandidate.Impl(impl, formalSelfTy, formalTraitRef) } @@ -657,7 +671,9 @@ class ImplLookup( val (subst, preparedRef) = candidate.prepareSubstAndTraitRef(ctx, ref.selfTy) ctx.combineTraitRefs(ref, preparedRef) // pre-resolve type vars to simplify caching of already inferred obligation on fulfillment - val candidateSubst = subst.mapTypeValues { (_, v) -> ctx.resolveTypeVarsIfPossible(v) } + + val candidateSubst = subst + .mapTypeValues { (_, v) -> ctx.resolveTypeVarsIfPossible(v) } + .mapConstValues { (_, v) -> ctx.resolveTypeVarsIfPossible(v) } + mapOf(TyTypeParameter.self() to ref.selfTy).toTypeSubst() val obligations = ctx.instantiateBounds(candidate.impl.bounds, candidateSubst, newRecDepth).toList() Selection(candidate.impl, obligations, candidateSubst) @@ -700,7 +716,10 @@ class ImplLookup( } is SelectionCandidate.HardcodedImpl -> { val impl = getHardcodedImpls(ref.selfTy).first { be -> - be.element == ref.trait.element && ctx.probe { ctx.combinePairs(be.subst.zipTypeValues(ref.trait.subst)).isOk } + be.element == ref.trait.element && ctx.probe { + ctx.combineTypePairs(be.subst.zipTypeValues(ref.trait.subst)).isOk && + ctx.combineConstPairs(be.subst.zipConstValues(ref.trait.subst)).isOk + } } ctx.combineBoundElements(impl, ref.trait) val obligations = getHardcodedImplPredicates(ref.selfTy, impl).map { Obligation(newRecDepth, it) } @@ -950,7 +969,7 @@ private sealed class SelectionCandidate { val formalTrait: BoundElement ) : SelectionCandidate() { fun prepareSubstAndTraitRef(ctx: RsInferenceContext, selfTy: Ty): Pair = - prepareSubstAndTraitRefRaw(ctx, impl.generics, formalSelfTy, formalTrait, selfTy) + prepareSubstAndTraitRefRaw(ctx, impl.generics, impl.constGenerics, formalSelfTy, formalTrait, selfTy) } data class DerivedTrait(val item: RsTraitItem) : SelectionCandidate() @@ -966,11 +985,15 @@ private sealed class SelectionCandidate { private fun prepareSubstAndTraitRefRaw( ctx: RsInferenceContext, generics: List, + constGenerics: List, formalSelfTy: Ty, formalTrait: BoundElement, selfTy: Ty ): Pair { - val subst = generics.associateWith { ctx.typeVarForParam(it) }.toTypeSubst() + val subst = Substitution( + typeSubst = generics.associateWith { ctx.typeVarForParam(it) }, + constSubst = constGenerics.associateWith { ctx.constVarForParam(it) } + ) val boundSubst = formalTrait.substitute(subst).subst.mapTypeValues { (k, v) -> if (k == v && k.parameter is TyTypeParameter.Named) { // Default type parameter values `trait Tr {}` diff --git a/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt b/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt index 828a6cde76e..9af82622f7c 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt @@ -48,6 +48,8 @@ import org.rust.lang.core.resolve.indexes.RsMacroIndex import org.rust.lang.core.resolve.ref.* import org.rust.lang.core.stubs.index.RsNamedElementIndex import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.CtInferVar +import org.rust.lang.core.types.infer.foldCtConstParameterWith import org.rust.lang.core.types.infer.foldTyTypeParameterWith import org.rust.lang.core.types.infer.substitute import org.rust.lang.core.types.ty.* @@ -419,7 +421,12 @@ private fun processQualifiedPathResolveVariants( // it means that all possible `TyInfer` has already substituted (with `_`) subst } else { - subst.mapTypeValues { (_, v) -> v.foldTyTypeParameterWith { TyInfer.TyVar(it) } } + subst + .mapTypeValues { (_, v) -> + v.foldTyTypeParameterWith { TyInfer.TyVar(it) } + .foldCtConstParameterWith { CtInferVar(it) } + } + .mapConstValues { (_, v) -> v.foldCtConstParameterWith { CtInferVar(it) } } } base.declaredType.substitute(realSubst) } @@ -429,7 +436,11 @@ private fun processQualifiedPathResolveVariants( val restrictedTraits = if (Namespace.Types in ns && base is RsImplItem && qualifier.hasCself) { NameResolutionTestmarks.selfRelatedTypeSpecialCase.hit() base.implementedTrait?.flattenHierarchy - ?.map { it.foldTyTypeParameterWith { TyInfer.TyVar(it) } } + ?.map { value -> + value + .foldTyTypeParameterWith { TyInfer.TyVar(it) } + .foldCtConstParameterWith { CtInferVar(it) } + } } else { null } @@ -451,7 +462,11 @@ private fun processExplicitTypeQualifiedPathResolveVariants( // TODO this is a hack to fix completion test `test associated type in explicit UFCS form`. // Looks like we should use getOriginalOrSelf during resolve ?.let { BoundElement(CompletionUtil.getOriginalOrSelf(it.element), it.subst) } - ?.let { it.foldTyTypeParameterWith { TyInfer.TyVar(it) } } + ?.let { value -> + value + .foldTyTypeParameterWith { TyInfer.TyVar(it) } + .foldCtConstParameterWith { CtInferVar(it) } + } val type = typeQual.typeReference.type return processTypeQualifiedPathResolveVariants(lookup, path, processor, ns, type, trait?.let { listOf(it) }) } @@ -546,6 +561,7 @@ private fun processTypeQualifiedPathResolveVariants( val implementedTrait = e.source.implementedTrait ?.foldTyTypeParameterWith { TyInfer.TyVar(it) } + ?.foldCtConstParameterWith { CtInferVar(it) } ?: return processor(e) val isAppropriateTrait = restrictedTraits.any { diff --git a/src/main/kotlin/org/rust/lang/core/resolve/RsCachedImplItem.kt b/src/main/kotlin/org/rust/lang/core/resolve/RsCachedImplItem.kt index bd07dffee26..9f2e119cc9f 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/RsCachedImplItem.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/RsCachedImplItem.kt @@ -17,6 +17,8 @@ import org.rust.lang.core.psi.isValidProjectMember import org.rust.lang.core.resolve.ref.ResolveCacheDependency import org.rust.lang.core.resolve.ref.RsResolveCache import org.rust.lang.core.types.BoundElement +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.infer.constGenerics import org.rust.lang.core.types.infer.generics import org.rust.lang.core.types.ty.Ty import org.rust.lang.core.types.ty.TyTypeParameter @@ -36,8 +38,8 @@ class RsCachedImplItem( val isInherent: Boolean get() = traitRef == null val implementedTrait: BoundElement? by lazy(PUBLICATION) { traitRef?.resolveToBoundTrait() } - val typeAndGenerics: Pair>? by lazy(PUBLICATION) { - impl.typeReference?.type?.let { it to impl.generics } + val typeAndGenerics: Triple, List>? by lazy(PUBLICATION) { + impl.typeReference?.type?.let { Triple(it, impl.generics, impl.constGenerics) } } /** For `impl T for Foo` returns union of impl members and trait `T` members that are not overriden by the impl */ diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt index 6177d913fa4..02de7884a3c 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt @@ -12,14 +12,17 @@ import org.rust.lang.core.psi.ext.* import org.rust.lang.core.resolve.* import org.rust.lang.core.types.BoundElement import org.rust.lang.core.types.Substitution +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtUnknown import org.rust.lang.core.types.infer.foldTyInferWith import org.rust.lang.core.types.infer.resolve import org.rust.lang.core.types.infer.substitute import org.rust.lang.core.types.inference import org.rust.lang.core.types.regions.ReEarlyBound -import org.rust.lang.core.types.regions.Region import org.rust.lang.core.types.ty.* import org.rust.lang.core.types.type +import org.rust.lang.utils.evaluation.PathExprResolver +import org.rust.lang.utils.evaluation.evaluate import org.rust.stdext.buildMap import org.rust.stdext.intersects @@ -122,7 +125,8 @@ fun resolvePath(path: RsPath, lookup: ImplLookup? = null): List instantiatePathGenerics( path: RsPath, - resolved: BoundElement + resolved: BoundElement, + resolver: PathExprResolver? = PathExprResolver.default ): BoundElement { val (element, subst) = resolved.downcast() ?: return resolved @@ -190,7 +194,19 @@ fun instantiatePathGenerics( val regionParameters = element.lifetimeParameters.map { ReEarlyBound(it) } val regionArguments = path.typeArgumentList?.lifetimeList?.map { it.resolve() } val regionSubst = regionParameters.zip(regionArguments ?: regionParameters).toMap() - val newSubst = Substitution(typeSubst, regionSubst) + + val constParameters = element.constParameters.map { CtConstParameter(it) } + val constArguments = path.typeArgumentList?.exprList?.withIndex()?.map { (i, expr) -> + val expectedTy = constParameters.getOrNull(i)?.parameter?.typeReference?.type ?: TyUnknown + expr.evaluate(expectedTy, resolver) + } + val constSubst = constParameters.withIndex().associate { (i, param) -> + val value = constArguments?.getOrNull(i) + ?: if (areOptionalArgs && constArguments == null) param else CtUnknown + param to value + } + + val newSubst = Substitution(typeSubst, regionSubst, constSubst) return BoundElement(resolved.element, subst + newSubst, assocTypes) } diff --git a/src/main/kotlin/org/rust/lang/core/stubs/StubImplementations.kt b/src/main/kotlin/org/rust/lang/core/stubs/StubImplementations.kt index dec14a685fa..d9dca65ed9a 100644 --- a/src/main/kotlin/org/rust/lang/core/stubs/StubImplementations.kt +++ b/src/main/kotlin/org/rust/lang/core/stubs/StubImplementations.kt @@ -49,7 +49,7 @@ class RsFileStub : PsiFileStubImpl { override fun getType() = Type object Type : IStubFileElementType(RsLanguage) { - private const val STUB_VERSION = 193 + private const val STUB_VERSION = 194 // Bump this number if Stub structure changes override fun getStubVersion(): Int = RustParserDefinition.PARSER_VERSION + STUB_VERSION @@ -255,7 +255,7 @@ fun factory(name: String): RsStubElementType<*, *> = when (name) { "ARRAY_EXPR" -> RsExprStubType("ARRAY_EXPR", ::RsArrayExprImpl) "BINARY_EXPR" -> RsExprStubType("BINARY_EXPR", ::RsBinaryExprImpl) - "BLOCK_EXPR" -> RsExprStubType("BLOCK_EXPR", ::RsBlockExprImpl) + "BLOCK_EXPR" -> RsBlockExprStub.Type "BREAK_EXPR" -> RsExprStubType("BREAK_EXPR", ::RsBreakExprImpl) "CALL_EXPR" -> RsExprStubType("CALL_EXPR", ::RsCallExprImpl) "CAST_EXPR" -> RsExprStubType("CAST_EXPR", ::RsCastExprImpl) @@ -1393,6 +1393,43 @@ class RsExprStubType( override fun shouldCreateStub(node: ASTNode): Boolean = shouldCreateExprStub(node) } +class RsBlockExprStub( + parent: StubElement<*>?, elementType: IStubElementType<*, *>, + private val flags: Int +) : RsPlaceholderStub(parent, elementType) { + val isUnsafe: Boolean get() = BitUtil.isSet(flags, UNSAFE_MASK) + val isAsync: Boolean get() = BitUtil.isSet(flags, ASYNC_MASK) + val isTry: Boolean get() = BitUtil.isSet(flags, TRY_MASK) + + object Type : RsStubElementType("BLOCK_EXPR") { + + override fun shouldCreateStub(node: ASTNode): Boolean = shouldCreateExprStub(node) + + override fun serialize(stub: RsBlockExprStub, dataStream: StubOutputStream) { + dataStream.writeInt(stub.flags) + } + + override fun deserialize(dataStream: StubInputStream, parentStub: StubElement<*>?): RsBlockExprStub = + RsBlockExprStub(parentStub, this, dataStream.readInt()) + + override fun createStub(psi: RsBlockExpr, parentStub: StubElement<*>?): RsBlockExprStub { + var flags = 0 + flags = BitUtil.set(flags, UNSAFE_MASK, psi.isUnsafe) + flags = BitUtil.set(flags, ASYNC_MASK, psi.isAsync) + flags = BitUtil.set(flags, TRY_MASK, psi.isTry) + return RsBlockExprStub(parentStub, this, flags) + } + + override fun createPsi(stub: RsBlockExprStub): RsBlockExpr = RsBlockExprImpl(stub, this) + } + + companion object { + private val UNSAFE_MASK: Int = makeBitMask(0) + private val ASYNC_MASK: Int = makeBitMask(1) + private val TRY_MASK: Int = makeBitMask(2) + } +} + class RsLitExprStub( parent: StubElement<*>?, elementType: IStubElementType<*, *>, val kind: RsStubLiteralKind? diff --git a/src/main/kotlin/org/rust/lang/core/types/BoundElement.kt b/src/main/kotlin/org/rust/lang/core/types/BoundElement.kt index afdac89e09a..731001b1928 100644 --- a/src/main/kotlin/org/rust/lang/core/types/BoundElement.kt +++ b/src/main/kotlin/org/rust/lang/core/types/BoundElement.kt @@ -10,7 +10,10 @@ import com.intellij.psi.ResolveResult import org.rust.lang.core.psi.RsTypeAlias import org.rust.lang.core.psi.ext.RsElement import org.rust.lang.core.psi.ext.RsGenericDeclaration +import org.rust.lang.core.psi.ext.constParameters import org.rust.lang.core.psi.ext.typeParameters +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter import org.rust.lang.core.types.infer.TypeFoldable import org.rust.lang.core.types.infer.TypeFolder import org.rust.lang.core.types.infer.TypeVisitor @@ -51,10 +54,15 @@ data class BoundElement( assoc.mapValues { (_, value) -> value.foldWith(folder) } ) - override fun superVisitWith(visitor: TypeVisitor): Boolean = assoc.values.any { visitor.visitTy(it) } || - subst.types.any { it.visitWith(visitor) } || subst.regions.any { it.visitWith(visitor) } - + override fun superVisitWith(visitor: TypeVisitor): Boolean = + assoc.values.any { visitor.visitTy(it) } || + subst.types.any { it.visitWith(visitor) } || + subst.regions.any { it.visitWith(visitor) } || + subst.consts.any { it.visitWith(visitor) } } val BoundElement.positionalTypeArguments: List get() = element.typeParameters.map { subst[it] ?: TyTypeParameter.named(it) } + +val BoundElement.positionalConstArguments: List + get() = element.constParameters.map { subst[it] ?: CtConstParameter(it) } diff --git a/src/main/kotlin/org/rust/lang/core/types/Extensions.kt b/src/main/kotlin/org/rust/lang/core/types/Extensions.kt index 3b8b73acb01..63b7d4204e2 100644 --- a/src/main/kotlin/org/rust/lang/core/types/Extensions.kt +++ b/src/main/kotlin/org/rust/lang/core/types/Extensions.kt @@ -42,9 +42,10 @@ private fun RsInferenceContextOwner.createResult(value: T): Result { // CachedValueProvider.Result can accept a ModificationTracker as a dependency, so the // cached value will be invalidated if the modification counter is incremented. - this is RsModificationTrackerOwner -> Result.create(value, structureModificationTracker, modificationTracker) - - else -> Result.create(value, structureModificationTracker) + else -> { + val modificationTracker = contextOrSelf()?.modificationTracker + Result.create(value, listOfNotNull(structureModificationTracker, modificationTracker)) + } } } diff --git a/src/main/kotlin/org/rust/lang/core/types/Kind.kt b/src/main/kotlin/org/rust/lang/core/types/Kind.kt index 3ca421ab7f6..d40d0f9c474 100644 --- a/src/main/kotlin/org/rust/lang/core/types/Kind.kt +++ b/src/main/kotlin/org/rust/lang/core/types/Kind.kt @@ -11,9 +11,12 @@ const val HAS_TY_INFER_MASK: TypeFlags = 1 const val HAS_TY_TYPE_PARAMETER_MASK: TypeFlags = 2 const val HAS_TY_PROJECTION_MASK: TypeFlags = 4 const val HAS_RE_EARLY_BOUND_MASK: TypeFlags = 8 +const val HAS_CT_INFER_MASK: TypeFlags = 16 +const val HAS_CT_PARAMETER_MASK: TypeFlags = 32 +const val HAS_CT_UNEVALUATED_MASK: TypeFlags = 64 /** - * An entity in the Rust type system, which can be one of several kinds (only types and lifetimes for now). + * An entity in the Rust type system, which can be one of several kinds (only types, lifetimes and constants for now). */ interface Kind { val flags: TypeFlags diff --git a/src/main/kotlin/org/rust/lang/core/types/Substitution.kt b/src/main/kotlin/org/rust/lang/core/types/Substitution.kt index bf649ed2b66..d7559aaf80b 100644 --- a/src/main/kotlin/org/rust/lang/core/types/Substitution.kt +++ b/src/main/kotlin/org/rust/lang/core/types/Substitution.kt @@ -5,8 +5,11 @@ package org.rust.lang.core.types +import org.rust.lang.core.psi.RsConstParameter import org.rust.lang.core.psi.RsLifetimeParameter import org.rust.lang.core.psi.RsTypeParameter +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter import org.rust.lang.core.types.infer.TypeFolder import org.rust.lang.core.types.infer.substitute import org.rust.lang.core.types.regions.ReEarlyBound @@ -18,19 +21,29 @@ import org.rust.stdext.zipValues open class Substitution( val typeSubst: Map = emptyMap(), - val regionSubst: Map = emptyMap() + val regionSubst: Map = emptyMap(), + val constSubst: Map = emptyMap() ) { val types: Collection get() = typeSubst.values val regions: Collection get() = regionSubst.values - val kinds: Collection get() = types + regions + val consts: Collection get() = constSubst.values + val kinds: Collection get() = (types as Collection) + regions + consts operator fun plus(other: Substitution): Substitution = - Substitution(mergeMaps(typeSubst, other.typeSubst), mergeMaps(regionSubst, other.regionSubst)) + Substitution( + mergeMaps(typeSubst, other.typeSubst), + mergeMaps(regionSubst, other.regionSubst), + mergeMaps(constSubst, other.constSubst) + ) + + operator fun get(key: TyTypeParameter): Ty? = typeSubst[key] + operator fun get(psi: RsTypeParameter): Ty? = typeSubst[TyTypeParameter.named(psi)] - operator fun get(key: TyTypeParameter) = typeSubst[key] - operator fun get(key: ReEarlyBound) = regionSubst[key] - operator fun get(psi: RsTypeParameter) = typeSubst[TyTypeParameter.named((psi))] - operator fun get(psi: RsLifetimeParameter) = regionSubst[ReEarlyBound((psi))] + operator fun get(key: ReEarlyBound): Region? = regionSubst[key] + operator fun get(psi: RsLifetimeParameter): Region? = regionSubst[ReEarlyBound(psi)] + + operator fun get(key: CtConstParameter): Const? = constSubst[key] + operator fun get(psi: RsConstParameter): Const? = constSubst[CtConstParameter(psi)] fun typeByName(name: String): Ty = typeSubst.entries.find { it.key.toString() == name }?.value ?: TyUnknown @@ -41,19 +54,26 @@ open class Substitution( fun substituteInValues(map: Substitution): Substitution = Substitution( typeSubst.mapValues { (_, value) -> value.substitute(map) }, - regionSubst.mapValues { (_, value) -> value.substitute(map) } + regionSubst.mapValues { (_, value) -> value.substitute(map) }, + constSubst.mapValues { (_, value) -> value.substitute(map) } ) fun foldValues(folder: TypeFolder): Substitution = Substitution( typeSubst.mapValues { (_, value) -> value.foldWith(folder) }, - regionSubst.mapValues { (_, value) -> value.foldWith(folder) } + regionSubst.mapValues { (_, value) -> value.foldWith(folder) }, + constSubst.mapValues { (_, value) -> value.foldWith(folder) } ) fun zipTypeValues(other: Substitution): List> = zipValues(typeSubst, other.typeSubst) + fun zipConstValues(other: Substitution): List> = zipValues(constSubst, other.constSubst) + fun mapTypeValues(transform: (Map.Entry) -> Ty): Substitution = - Substitution(typeSubst.mapValues(transform), regionSubst) + Substitution(typeSubst.mapValues(transform), regionSubst, constSubst) + + fun mapConstValues(transform: (Map.Entry) -> Const): Substitution = + Substitution(typeSubst, regionSubst, constSubst.mapValues(transform)) override fun equals(other: Any?): Boolean = when { this === other -> true @@ -70,7 +90,7 @@ private object EmptySubstitution : Substitution() val emptySubstitution: Substitution = EmptySubstitution -fun Map.toTypeSubst() = Substitution(typeSubst = this) +fun Map.toTypeSubst(): Substitution = Substitution(typeSubst = this) private fun mergeMaps(map1: Map, map2: Map): Map = if (map1.isEmpty() && map2.isEmpty()) emptyMap() else HashMap(map1).apply { putAll(map2) } diff --git a/src/main/kotlin/org/rust/lang/core/types/consts/Const.kt b/src/main/kotlin/org/rust/lang/core/types/consts/Const.kt new file mode 100644 index 00000000000..b806f9b7186 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/types/consts/Const.kt @@ -0,0 +1,23 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.types.consts + +import org.rust.lang.core.types.Kind +import org.rust.lang.core.types.TypeFlags +import org.rust.lang.core.types.infer.TypeFoldable +import org.rust.lang.core.types.infer.TypeFolder +import org.rust.lang.core.types.infer.TypeVisitor + +abstract class Const(override val flags: TypeFlags = 0) : Kind, TypeFoldable { + + override fun foldWith(folder: TypeFolder): Const = folder.foldConst(this) + + override fun superFoldWith(folder: TypeFolder): Const = this + + override fun visitWith(visitor: TypeVisitor): Boolean = visitor.visitConst(this) + + override fun superVisitWith(visitor: TypeVisitor): Boolean = false +} diff --git a/src/main/kotlin/org/rust/lang/core/types/consts/CtConstParameter.kt b/src/main/kotlin/org/rust/lang/core/types/consts/CtConstParameter.kt new file mode 100644 index 00000000000..00112c03871 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/types/consts/CtConstParameter.kt @@ -0,0 +1,20 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.types.consts + +import com.intellij.codeInsight.completion.CompletionUtil +import org.rust.lang.core.psi.RsConstParameter +import org.rust.lang.core.types.HAS_CT_PARAMETER_MASK + +class CtConstParameter(parameter: RsConstParameter) : Const(HAS_CT_PARAMETER_MASK) { + val parameter: RsConstParameter = CompletionUtil.getOriginalOrSelf(parameter) + + override fun equals(other: Any?): Boolean = other is CtConstParameter && other.parameter == parameter + + override fun hashCode(): Int = parameter.hashCode() + + override fun toString(): String = parameter.name ?: "" +} diff --git a/src/main/kotlin/org/rust/lang/core/types/consts/CtInfer.kt b/src/main/kotlin/org/rust/lang/core/types/consts/CtInfer.kt new file mode 100644 index 00000000000..4387e78beb3 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/types/consts/CtInfer.kt @@ -0,0 +1,18 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.types.consts + +import org.rust.lang.core.types.HAS_CT_INFER_MASK +import org.rust.lang.core.types.infer.Node +import org.rust.lang.core.types.infer.NodeOrValue +import org.rust.lang.core.types.infer.VarValue + +class CtInferVar( + val origin: Const? = null, + override var parent: NodeOrValue = VarValue(null, 0) +) : Const(HAS_CT_INFER_MASK), Node + +data class FreshCtInferVar(val id: Int) : Const() diff --git a/src/main/kotlin/org/rust/lang/core/types/consts/CtUnevaluated.kt b/src/main/kotlin/org/rust/lang/core/types/consts/CtUnevaluated.kt new file mode 100644 index 00000000000..697ed829635 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/types/consts/CtUnevaluated.kt @@ -0,0 +1,21 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.types.consts + +import org.rust.lang.core.types.HAS_CT_UNEVALUATED_MASK +import org.rust.lang.core.types.infer.TypeFolder +import org.rust.lang.core.types.infer.TypeVisitor +import org.rust.lang.utils.evaluation.ConstExpr + +data class CtUnevaluated(val expr: ConstExpr<*>) : Const(HAS_CT_UNEVALUATED_MASK or expr.flags) { + override fun superFoldWith(folder: TypeFolder): Const = + CtUnevaluated(expr.foldWith(folder)) + + override fun superVisitWith(visitor: TypeVisitor): Boolean = + expr.visitWith(visitor) + + override fun toString(): String = "" +} diff --git a/src/main/kotlin/org/rust/lang/core/types/consts/CtUnknown.kt b/src/main/kotlin/org/rust/lang/core/types/consts/CtUnknown.kt new file mode 100644 index 00000000000..218c0b668f7 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/types/consts/CtUnknown.kt @@ -0,0 +1,10 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.types.consts + +object CtUnknown : Const() { + override fun toString(): String = "" +} diff --git a/src/main/kotlin/org/rust/lang/core/types/consts/CtValue.kt b/src/main/kotlin/org/rust/lang/core/types/consts/CtValue.kt new file mode 100644 index 00000000000..d198b15955b --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/types/consts/CtValue.kt @@ -0,0 +1,22 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.types.consts + +import org.rust.lang.utils.evaluation.ConstExpr + +fun Const.asBool(): Boolean? { + if (this !is CtValue) return null + return (expr as? ConstExpr.Value.Bool)?.value +} + +fun Const.asLong(): Long? { + if (this !is CtValue) return null + return (expr as? ConstExpr.Value.Integer)?.value +} + +data class CtValue(val expr: ConstExpr.Value<*>) : Const() { + override fun toString(): String = expr.toString() +} diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/Declarations.kt b/src/main/kotlin/org/rust/lang/core/types/infer/Declarations.kt index 53b21fa00ae..f81cf585a90 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/Declarations.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/Declarations.kt @@ -8,12 +8,17 @@ package org.rust.lang.core.types.infer import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* import org.rust.lang.core.types.Substitution +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtUnknown import org.rust.lang.core.types.regions.ReEarlyBound import org.rust.lang.core.types.regions.ReStatic import org.rust.lang.core.types.regions.ReUnknown import org.rust.lang.core.types.regions.Region import org.rust.lang.core.types.ty.* import org.rust.lang.core.types.type +import org.rust.lang.utils.evaluation.evaluate +import org.rust.lang.utils.evaluation.tryEvaluate // Keep in sync with TyFingerprint-create @@ -84,7 +89,8 @@ fun inferTypeReferenceType(ref: RsTypeReference, defaultTraitObjectRegion: Regio if (type.isSlice) { TySlice(componentType) } else { - TyArray(componentType, type.arraySize) + val const = type.expr?.evaluate(TyInteger.USize) ?: CtUnknown + TyArray(componentType, const) } } @@ -117,19 +123,25 @@ fun RsLifetime?.resolve(): Region { return if (resolved is RsLifetimeParameter) ReEarlyBound(resolved) else ReUnknown } -private fun TypeFoldable.substituteWithTraitObjectRegion( +private fun > TypeFoldable.substituteWithTraitObjectRegion( subst: Substitution, defaultTraitObjectRegion: Region ): T = foldWith(object : TypeFolder { override fun foldTy(ty: Ty): Ty = when { ty is TyTypeParameter -> handleTraitObject(ty) ?: ty - ty.needToSubstitute -> ty.superFoldWith(this) + ty.needsSubst -> ty.superFoldWith(this) else -> ty } override fun foldRegion(region: Region): Region = (region as? ReEarlyBound)?.let { subst[it] } ?: region + override fun foldConst(const: Const): Const = when { + const is CtConstParameter -> subst[const] ?: const + const.hasCtConstParameters -> const.superFoldWith(this) + else -> const + } + fun handleTraitObject(paramTy: TyTypeParameter): Ty? { val ty = subst[paramTy] if (ty !is TyTraitObject || ty.region !is ReUnknown) return ty @@ -141,4 +153,4 @@ private fun TypeFoldable.substituteWithTraitObjectRegion( } return TyTraitObject(ty.trait, region) } -}) +}).tryEvaluate() diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/Fold.kt b/src/main/kotlin/org/rust/lang/core/types/infer/Fold.kt index 43918fa5980..1cd0931fb59 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/Fold.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/Fold.kt @@ -7,22 +7,33 @@ package org.rust.lang.core.types.infer import com.intellij.util.BitUtil import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtInferVar +import org.rust.lang.core.types.consts.CtUnknown +import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_CT_INFER_VISITOR +import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_CT_PARAMETER_VISITOR +import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_CT_UNEVALUATED_VISITOR import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_RE_EARLY_BOUND_VISITOR import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_TY_INFER_VISITOR import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_TY_PROJECTION_VISITOR import org.rust.lang.core.types.infer.HasTypeFlagVisitor.Companion.HAS_TY_TYPE_PARAMETER_VISITOR import org.rust.lang.core.types.regions.ReEarlyBound +import org.rust.lang.core.types.regions.ReUnknown import org.rust.lang.core.types.regions.Region import org.rust.lang.core.types.ty.* +import org.rust.lang.utils.evaluation.tryEvaluate interface TypeFolder { fun foldTy(ty: Ty): Ty = ty fun foldRegion(region: Region): Region = region + fun foldConst(const: Const): Const = const } interface TypeVisitor { fun visitTy(ty: Ty): Boolean = false fun visitRegion(region: Region): Boolean = false + fun visitConst(const: Const): Boolean = false } /** @@ -89,6 +100,27 @@ fun TypeFoldable.foldTyTypeParameterWith(folder: (TyTypeParameter) -> Ty) } }) +/** Deeply replace any [CtInferVar] with the function [folder] */ +fun > TypeFoldable.foldCtInferWith(folder: (CtInferVar) -> Const): T = + foldWith(object : TypeFolder { + override fun foldTy(ty: Ty): Ty = if (ty.hasCtInfer) ty.superFoldWith(this) else ty + override fun foldConst(const: Const): Const { + val foldedCt = if (const is CtInferVar) folder(const) else const + return if (foldedCt.hasCtInfer) foldedCt.superFoldWith(this) else foldedCt + } + }).tryEvaluate() + +/** Deeply replace any [CtConstParameter] with the function [folder] */ +fun > TypeFoldable.foldCtConstParameterWith(folder: (CtConstParameter) -> Const): T = + foldWith(object : TypeFolder { + override fun foldTy(ty: Ty): Ty = if (ty.hasCtConstParameters) ty.superFoldWith(this) else ty + override fun foldConst(const: Const): Const = when { + const is CtConstParameter -> folder(const) + const.hasCtConstParameters -> const.superFoldWith(this) + else -> const + } + }).tryEvaluate() + /** Deeply replace any [TyProjection] with the function [folder] */ fun TypeFoldable.foldTyProjectionWith(folder: (TyProjection) -> Ty): T = foldWith(object : TypeFolder { @@ -100,31 +132,43 @@ fun TypeFoldable.foldTyProjectionWith(folder: (TyProjection) -> Ty): T = }) /** - * Deeply replace any [TyTypeParameter] by [subst] mapping. + * Deeply replace any [TyTypeParameter], [ReEarlyBound] and [CtConstParameter] by [subst] mapping. */ -fun TypeFoldable.substitute(subst: Substitution): T = +fun > TypeFoldable.substitute(subst: Substitution): T = foldWith(object : TypeFolder { override fun foldTy(ty: Ty): Ty = when { ty is TyTypeParameter -> subst[ty] ?: ty - ty.needToSubstitute -> ty.superFoldWith(this) + ty.needsSubst -> ty.superFoldWith(this) else -> ty } override fun foldRegion(region: Region): Region = (region as? ReEarlyBound)?.let { subst[it] } ?: region - }) -fun TypeFoldable.substituteOrUnknown(subst: Substitution): T = + override fun foldConst(const: Const): Const = when { + const is CtConstParameter -> subst[const] ?: const + const.hasCtConstParameters -> const.superFoldWith(this) + else -> const + } + }).tryEvaluate() + +fun > TypeFoldable.substituteOrUnknown(subst: Substitution): T = foldWith(object : TypeFolder { override fun foldTy(ty: Ty): Ty = when { ty is TyTypeParameter -> subst[ty] ?: TyUnknown - ty.needToSubstitute -> ty.superFoldWith(this) + ty.needsSubst -> ty.superFoldWith(this) else -> ty } override fun foldRegion(region: Region): Region = - (region as? ReEarlyBound)?.let { subst[it] } ?: region - }) + (region as? ReEarlyBound)?.let { subst[it] } ?: ReUnknown + + override fun foldConst(const: Const): Const = when { + const is CtConstParameter -> subst[const] ?: CtUnknown + const.hasCtConstParameters -> const.superFoldWith(this) + else -> const + } + }).tryEvaluate() fun TypeFoldable.containsTyOfClass(classes: List>): Boolean = visitWith(object : TypeVisitor { @@ -135,6 +179,15 @@ fun TypeFoldable.containsTyOfClass(classes: List>): Boolean = fun TypeFoldable.containsTyOfClass(vararg classes: Class<*>): Boolean = containsTyOfClass(classes.toList()) +fun TypeFoldable.containsConstOfClass(classes: List>): Boolean = + visitWith(object : TypeVisitor { + override fun visitTy(ty: Ty): Boolean = ty.superVisitWith(this) + override fun visitConst(const: Const): Boolean = classes.any { it.isInstance(const) } + }) + +fun TypeFoldable.containsConstOfClass(vararg classes: Class<*>): Boolean = + containsConstOfClass(classes.toList()) + fun TypeFoldable.collectInferTys(): List { val list = mutableListOf() visitInferTys { @@ -157,12 +210,16 @@ fun TypeFoldable.visitInferTys(visitor: (TyInfer) -> Boolean): Boolean { private data class HasTypeFlagVisitor(val flag: TypeFlags) : TypeVisitor { override fun visitTy(ty: Ty): Boolean = BitUtil.isSet(ty.flags, flag) override fun visitRegion(region: Region): Boolean = BitUtil.isSet(region.flags, flag) + override fun visitConst(const: Const): Boolean = BitUtil.isSet(const.flags, flag) companion object { val HAS_TY_INFER_VISITOR = HasTypeFlagVisitor(HAS_TY_INFER_MASK) val HAS_TY_TYPE_PARAMETER_VISITOR = HasTypeFlagVisitor(HAS_TY_TYPE_PARAMETER_MASK) val HAS_TY_PROJECTION_VISITOR = HasTypeFlagVisitor(HAS_TY_PROJECTION_MASK) val HAS_RE_EARLY_BOUND_VISITOR = HasTypeFlagVisitor(HAS_RE_EARLY_BOUND_MASK) + val HAS_CT_INFER_VISITOR = HasTypeFlagVisitor(HAS_CT_INFER_MASK) + val HAS_CT_PARAMETER_VISITOR = HasTypeFlagVisitor(HAS_CT_PARAMETER_MASK) + val HAS_CT_UNEVALUATED_VISITOR = HasTypeFlagVisitor(HAS_CT_UNEVALUATED_MASK) } } @@ -178,5 +235,20 @@ val TypeFoldable<*>.hasTyProjection val TypeFoldable<*>.hasReEarlyBounds get(): Boolean = visitWith(HAS_RE_EARLY_BOUND_VISITOR) -val TypeFoldable<*>.needToSubstitute - get(): Boolean = hasTyTypeParameters || hasReEarlyBounds +val TypeFoldable<*>.hasCtInfer + get(): Boolean = visitWith(HAS_CT_INFER_VISITOR) + +val TypeFoldable<*>.hasCtUnevaluated + get(): Boolean = visitWith(HAS_CT_UNEVALUATED_VISITOR) + +val TypeFoldable<*>.hasCtConstParameters + get(): Boolean = visitWith(HAS_CT_PARAMETER_VISITOR) + +val TypeFoldable<*>.needsInfer + get(): Boolean = hasTyInfer || hasCtInfer + +val TypeFoldable<*>.needsSubst + get(): Boolean = hasTyTypeParameters || hasReEarlyBounds || hasCtConstParameters + +val TypeFoldable<*>.needsEval + get(): Boolean = hasCtUnevaluated && !hasCtConstParameters diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt index 14bacf02690..ae390bd54d7 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt @@ -18,6 +18,7 @@ import org.rust.lang.core.resolve.* import org.rust.lang.core.resolve.ref.MethodResolveVariant import org.rust.lang.core.resolve.ref.resolvePathRaw import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.* import org.rust.lang.core.types.regions.Region import org.rust.lang.core.types.ty.* import org.rust.lang.utils.RsDiagnostic @@ -139,12 +140,14 @@ class RsInferenceContext( private val intUnificationTable: UnificationTable = UnificationTable() private val floatUnificationTable: UnificationTable = UnificationTable() private val varUnificationTable: UnificationTable = UnificationTable() + private val constUnificationTable: UnificationTable = UnificationTable() private val projectionCache: ProjectionCache = ProjectionCache() fun startSnapshot(): Snapshot = CombinedSnapshot( intUnificationTable.startSnapshot(), floatUnificationTable.startSnapshot(), varUnificationTable.startSnapshot(), + constUnificationTable.startSnapshot(), projectionCache.startSnapshot() ) @@ -435,12 +438,12 @@ class RsInferenceContext( is TyInfer.IntVar -> when (ty2) { is TyInfer.IntVar -> intUnificationTable.unifyVarVar(ty1, ty2) is TyInteger -> intUnificationTable.unifyVarValue(ty1, ty2) - else -> return CoerceResult.Mismatch(ty1, ty2) + else -> return CoerceResult.TypeMismatch(ty1, ty2) } is TyInfer.FloatVar -> when (ty2) { is TyInfer.FloatVar -> floatUnificationTable.unifyVarVar(ty1, ty2) is TyFloat -> floatUnificationTable.unifyVarValue(ty1, ty2) - else -> return CoerceResult.Mismatch(ty1, ty2) + else -> return CoerceResult.TypeMismatch(ty1, ty2) } is TyInfer.TyVar -> error("unreachable") } @@ -461,28 +464,64 @@ class RsInferenceContext( ty1 is TyPointer && ty2 is TyPointer && ty1.mutability == ty2.mutability -> { combineTypes(ty1.referenced, ty2.referenced) } - ty1 is TyArray && ty2 is TyArray && - (ty1.size == null || ty2.size == null || ty1.size == ty2.size) -> combineTypes(ty1.base, ty2.base) + ty1 is TyArray && ty2 is TyArray && (ty1.size == null || ty2.size == null || ty1.size == ty2.size) -> + combineTypes(ty1.base, ty2.base).and { combineConsts(ty1.const, ty2.const) } ty1 is TySlice && ty2 is TySlice -> combineTypes(ty1.elementType, ty2.elementType) ty1 is TyTuple && ty2 is TyTuple && ty1.types.size == ty2.types.size -> { - combinePairs(ty1.types.zip(ty2.types)) + combineTypePairs(ty1.types.zip(ty2.types)) } ty1 is TyFunction && ty2 is TyFunction && ty1.paramTypes.size == ty2.paramTypes.size -> { - combinePairs(ty1.paramTypes.zip(ty2.paramTypes)).and { combineTypes(ty1.retType, ty2.retType) } + combineTypePairs(ty1.paramTypes.zip(ty2.paramTypes)) + .and { combineTypes(ty1.retType, ty2.retType) } } ty1 is TyAdt && ty2 is TyAdt && ty1.item == ty2.item -> { - combinePairs(ty1.typeArguments.zip(ty2.typeArguments)) + combineTypePairs(ty1.typeArguments.zip(ty2.typeArguments)) + .and { combineConstPairs(ty1.constArguments.zip(ty2.constArguments)) } } ty1 is TyTraitObject && ty2 is TyTraitObject && combineBoundElements(ty1.trait, ty2.trait) -> CoerceResult.Ok ty1 is TyAnon && ty2 is TyAnon && ty1.definition != null && ty1.definition == ty2.definition -> CoerceResult.Ok ty1 is TyNever || ty2 is TyNever -> CoerceResult.Ok - else -> CoerceResult.Mismatch(ty1, ty2) + else -> CoerceResult.TypeMismatch(ty1, ty2) } - fun combinePairs(pairs: List>): CoerceResult { + fun combineConsts(const1: Const, const2: Const): CoerceResult { + return combineConstsResolved(shallowResolve(const1), shallowResolve(const2)) + } + + private fun combineConstsResolved(const1: Const, const2: Const): CoerceResult = + when { + const1 is CtInferVar -> combineConstVar(const1, const2) + const2 is CtInferVar -> combineConstVar(const2, const1) + else -> combineConstsNoVars(const1, const2) + } + + private fun combineConstVar(const1: CtInferVar, const2: Const): CoerceResult { + if (const2 is CtInferVar) { + constUnificationTable.unifyVarVar(const1, const2) + } else { + val const1r = constUnificationTable.findRoot(const1) + constUnificationTable.unifyVarValue(const1r, const2) + } + return CoerceResult.Ok + } + + private fun combineConstsNoVars(const1: Const, const2: Const): CoerceResult = + when { + const1 === const2 -> CoerceResult.Ok + const1 is CtUnknown || const2 is CtUnknown -> CoerceResult.Ok + const1 is CtUnevaluated || const2 is CtUnevaluated -> CoerceResult.Ok + const1 == const2 -> CoerceResult.Ok + else -> CoerceResult.ConstMismatch(const1, const2) + } + + fun combineTypePairs(pairs: List>): CoerceResult = combinePairs(pairs, ::combineTypes) + + fun combineConstPairs(pairs: List>): CoerceResult = combinePairs(pairs, ::combineConsts) + + private fun combinePairs(pairs: List>, combine: (T, T) -> CoerceResult): CoerceResult { var canUnify: CoerceResult = CoerceResult.Ok - for ((t1, t2) in pairs) { - canUnify = combineTypes(t1, t2).and { canUnify } + for ((k1, k2) in pairs) { + canUnify = combine(k1, k2).and { canUnify } } return canUnify } @@ -492,45 +531,89 @@ class RsInferenceContext( combineTypes(ref1.selfTy, ref2.selfTy).isOk && ref1.trait.subst.zipTypeValues(ref2.trait.subst).all { (a, b) -> combineTypes(a, b).isOk + } && + ref1.trait.subst.zipConstValues(ref2.trait.subst).all { (a, b) -> + combineConsts(a, b).isOk } fun combineBoundElements(be1: BoundElement, be2: BoundElement): Boolean = be1.element == be2.element && - combinePairs(be1.subst.zipTypeValues(be2.subst)).isOk && - combinePairs(zipValues(be1.assoc, be2.assoc)).isOk + combineTypePairs(be1.subst.zipTypeValues(be2.subst)).isOk && + combineConstPairs(be1.subst.zipConstValues(be2.subst)).isOk && + combineTypePairs(zipValues(be1.assoc, be2.assoc)).isOk - fun shallowResolve(ty: Ty): Ty { - if (ty !is TyInfer) return ty + fun > shallowResolve(value: T): T = value.foldWith(shallowResolver) - return when (ty) { - is TyInfer.IntVar -> intUnificationTable.findValue(ty) ?: ty - is TyInfer.FloatVar -> floatUnificationTable.findValue(ty) ?: ty - is TyInfer.TyVar -> varUnificationTable.findValue(ty)?.let(this::shallowResolve) ?: ty - } - } + private inner class ShallowResolver : TypeFolder { - fun > resolveTypeVarsIfPossible(ty: T): T { - return ty.foldTyInferWith(this::shallowResolve) - } + override fun foldTy(ty: Ty): Ty = shallowResolve(ty) + + override fun foldConst(const: Const): Const = + if (const is CtInferVar) { + constUnificationTable.findValue(const) ?: const + } else { + const + } - fun > fullyResolve(ty: T): T { - fun go(ty: Ty): Ty { + private fun shallowResolve(ty: Ty): Ty { if (ty !is TyInfer) return ty return when (ty) { - is TyInfer.IntVar -> intUnificationTable.findValue(ty) ?: TyUnknown - is TyInfer.FloatVar -> floatUnificationTable.findValue(ty) ?: TyUnknown - is TyInfer.TyVar -> varUnificationTable.findValue(ty)?.let(::go) ?: TyUnknown + is TyInfer.IntVar -> intUnificationTable.findValue(ty) ?: ty + is TyInfer.FloatVar -> floatUnificationTable.findValue(ty) ?: ty + is TyInfer.TyVar -> varUnificationTable.findValue(ty)?.let(this::shallowResolve) ?: ty } } + } + + private val shallowResolver: ShallowResolver = ShallowResolver() + + fun > resolveTypeVarsIfPossible(value: T): T = value.foldWith(opportunisticVarResolver) - return ty.foldTyInferWith(::go) + private inner class OpportunisticVarResolver : TypeFolder { + override fun foldTy(ty: Ty): Ty { + if (!ty.needsInfer) return ty + val res = shallowResolve(ty) + return res.superFoldWith(this) + } + + override fun foldConst(const: Const): Const { + if (!const.hasCtInfer) return const + val res = shallowResolve(const) + return res.superFoldWith(this) + } } - fun typeVarForParam(ty: TyTypeParameter): Ty { - return TyInfer.TyVar(ty) + private val opportunisticVarResolver: OpportunisticVarResolver = OpportunisticVarResolver() + + /** + * Full type resolution replaces all type and const variables with their concrete results. + */ + fun > fullyResolve(value: T): T { + return value.foldWith(fullTypeResolver) } + private inner class FullTypeResolver : TypeFolder { + override fun foldTy(ty: Ty): Ty { + if (!ty.needsInfer) return ty + val res = shallowResolve(ty) + return if (res is TyInfer) TyUnknown else res.superFoldWith(this) + } + + override fun foldConst(const: Const): Const = + if (const is CtInferVar) { + constUnificationTable.findValue(const) ?: CtUnknown + } else { + const + } + } + + private val fullTypeResolver: FullTypeResolver = FullTypeResolver() + + fun typeVarForParam(ty: TyTypeParameter): Ty = TyInfer.TyVar(ty) + + fun constVarForParam(const: CtConstParameter): Const = CtInferVar(const) + /** Deeply normalize projection types. See [normalizeProjectionType] */ fun > normalizeAssociatedTypesIn(ty: T, recursionDepth: Int = 0): TyWithObligations { val obligations = mutableListOf() @@ -694,14 +777,17 @@ class RsInferenceContext( fun instantiateBounds( element: RsGenericDeclaration, selfTy: Ty? = null, - typeParameters: Substitution = emptySubstitution + subst: Substitution = emptySubstitution ): Substitution { val map = run { - val map = element + val typeSubst = element .generics .associateWith { typeVarForParam(it) } .let { if (selfTy != null) it + (TyTypeParameter.self() to selfTy) else it } - typeParameters + map.toTypeSubst() + val constSubst = element + .constGenerics + .associateWith { constVarForParam(it) } + subst + Substitution(typeSubst = typeSubst, constSubst = constSubst) } instantiateBounds(element.bounds, map).forEach(fulfill::registerPredicateObligation) return map @@ -722,7 +808,10 @@ class RsInferenceContext( /** Checks that [selfTy] satisfies all trait bounds of the [impl] */ private fun canEvaluateBounds(impl: RsImplItem, selfTy: Ty): Boolean { val ff = FulfillmentContext(this, lookup) - val subst = impl.generics.associateWith { typeVarForParam(it) }.toTypeSubst() + val subst = Substitution( + typeSubst = impl.generics.associateWith { typeVarForParam(it) }, + constSubst = impl.constGenerics.associateWith { constVarForParam(it) } + ) return probe { instantiateBounds(impl.bounds, subst).forEach(ff::registerPredicateObligation) impl.typeReference?.type?.substitute(subst)?.let { combineTypes(selfTy, it) } @@ -770,7 +859,10 @@ class RsInferenceContext( // Method path refinement needed if there are multiple impls of the same trait to the same type val trait = source.value as RsTraitItem val typeParameters = instantiateBounds(trait) - val subst = trait.generics.associateBy { it }.toTypeSubst() + val subst = Substitution( + typeSubst = trait.generics.associateBy { it }, + constSubst = trait.constGenerics.associateBy { it } + ) val boundTrait = BoundElement(trait, subst).substitute(typeParameters) val traitRef = TraitRef(callee.selfTy, boundTrait) fulfill.registerPredicateObligation(Obligation(Predicate.Trait(traitRef))) @@ -789,6 +881,9 @@ class RsInferenceContext( val RsGenericDeclaration.generics: List get() = typeParameters.map { TyTypeParameter.named(it) } +val RsGenericDeclaration.constGenerics: List + get() = constParameters.map { CtConstParameter(it) } + val RsGenericDeclaration.bounds: List get() = CachedValuesManager.getCachedValue(this) { CachedValueProvider.Result.create( @@ -855,15 +950,13 @@ sealed class ResolvedPath { sealed class CoerceResult { object Ok : CoerceResult() - class Mismatch(val ty1: Ty, val ty2: Ty) : CoerceResult() + class TypeMismatch(val ty1: Ty, val ty2: Ty) : CoerceResult() + class ConstMismatch(val const1: Const, val const2: Const) : CoerceResult() - val isOk: Boolean get() = this == Ok + val isOk: Boolean get() = this is Ok } -inline fun CoerceResult.and(rhs: () -> CoerceResult): CoerceResult = when (this) { - is CoerceResult.Mismatch -> this - CoerceResult.Ok -> rhs() -} +inline fun CoerceResult.and(rhs: () -> CoerceResult): CoerceResult = if (isOk) rhs() else this object TypeInferenceMarks { val cyclicType = Testmark("cyclicType") diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt index 70cf61fb5cd..f5ed14b81c9 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt @@ -15,10 +15,16 @@ import org.rust.lang.core.resolve.* import org.rust.lang.core.resolve.ref.* import org.rust.lang.core.stubs.RsStubLiteralKind import org.rust.lang.core.types.* +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtInferVar +import org.rust.lang.core.types.consts.CtUnknown import org.rust.lang.core.types.ty.* import org.rust.lang.utils.RsDiagnostic -import org.rust.lang.utils.evaluation.ExprValue -import org.rust.lang.utils.evaluation.RsConstExprEvaluator +import org.rust.lang.utils.evaluation.ConstExpr +import org.rust.lang.utils.evaluation.PathExprResolver +import org.rust.lang.utils.evaluation.evaluate +import org.rust.lang.utils.evaluation.toConst import org.rust.openapiext.forEachChild import org.rust.stdext.notEmptyOrLet import org.rust.stdext.singleOrFilter @@ -37,9 +43,9 @@ class RsTypeInferenceWalker( private val RsStructLiteralField.type: Ty get() = resolveToDeclaration()?.typeReference?.type ?: TyUnknown private fun resolveTypeVarsWithObligations(ty: Ty): Ty { - if (!ty.hasTyInfer) return ty + if (!ty.needsInfer) return ty val tyRes = ctx.resolveTypeVarsIfPossible(ty) - if (!tyRes.hasTyInfer) return tyRes + if (!tyRes.needsInfer) return tyRes selectObligationsWherePossible() return ctx.resolveTypeVarsIfPossible(tyRes) } @@ -203,34 +209,35 @@ class RsTypeInferenceWalker( resolveTypeVarsWithObligations(expected) ) - private fun coerceResolved(element: RsElement, inferred: Ty, expected: Ty): Boolean { + private fun coerceResolved(element: RsElement, inferred: Ty, expected: Ty): Boolean = when (val result = tryCoerce(inferred, expected)) { - CoerceResult.Ok -> return true - - is CoerceResult.Mismatch -> { - // ignoring possible false-positives (it's only basic experimental type checking) - val ignoredTys = listOf( - TyUnknown::class.java, - TyInfer.TyVar::class.java, - TyTypeParameter::class.java, - TyProjection::class.java, - TyTraitObject::class.java, - TyAnon::class.java - ) - - if (result.ty1.javaClass !in ignoredTys && result.ty2.javaClass !in ignoredTys + CoerceResult.Ok -> true + + is CoerceResult.TypeMismatch -> { + if (result.ty1.javaClass !in IGNORED_TYS && result.ty2.javaClass !in IGNORED_TYS && !(expected is TyReference && inferred is TyReference - && (expected.containsTyOfClass(ignoredTys) || inferred.containsTyOfClass(ignoredTys))) + && (expected.containsTyOfClass(IGNORED_TYS) || inferred.containsTyOfClass(IGNORED_TYS))) ) { - // another awful hack: check that inner expressions did not annotated as an error - // to disallow annotation intersections. This should be done in a different way - if (ctx.diagnostics.all { !element.isAncestorOf(it.element) }) { - ctx.reportTypeMismatch(element, expected, inferred) - } + reportTypeMismatch(element, expected, inferred) } - return false + false } + + is CoerceResult.ConstMismatch -> { + if (result.const1.javaClass !in IGNORED_CONSTS && result.const2.javaClass !in IGNORED_CONSTS) { + reportTypeMismatch(element, expected, inferred) + } + + false + } + } + + // Another awful hack: check that inner expressions did not annotated as an error + // to disallow annotation intersections. This should be done in a different way + private fun reportTypeMismatch(element: RsElement, expected: Ty, inferred: Ty) { + if (ctx.diagnostics.all { !element.isAncestorOf(it.element) }) { + ctx.reportTypeMismatch(element, expected, inferred) } } @@ -276,7 +283,7 @@ class RsTypeInferenceWalker( if (ctx.combineTypesIfOk(derefTyRef, expected)) return CoerceResult.Ok } - return CoerceResult.Mismatch(inferred, expected) + return CoerceResult.TypeMismatch(inferred, expected) } private fun inferLitExprType(expr: RsLitExpr, expected: Ty?): Ty { @@ -286,7 +293,9 @@ class RsTypeInferenceWalker( is RsStubLiteralKind.String -> { // TODO infer the actual lifetime if (stubKind.isByte) { - TyReference(TyArray(TyInteger.U8, stubKind.value?.length?.toLong() ?: 0), Mutability.IMMUTABLE) + val size = stubKind.value?.length?.toLong() + val const = size?.let { ConstExpr.Value.Integer(it, TyInteger.USize).toConst() } ?: CtUnknown + TyReference(TyArray(TyInteger.U8, const), Mutability.IMMUTABLE) } else { TyReference(TyStr, Mutability.IMMUTABLE) } @@ -369,19 +378,31 @@ class RsTypeInferenceWalker( /** See test `test type arguments remap on collapse to trait` */ private fun collapseSubst(parentFn: RsFunction, variants: List>): Substitution { - //TODO remap lifetimes - val collapsed = mutableMapOf() val generics = parentFn.generics + val typeSubst = mutableMapOf() for (fn in variants) { for ((key, newValue) in generics.zip(fn.positionalTypeArguments)) { @Suppress("NAME_SHADOWING") - collapsed.compute(key) { key, oldValue -> + typeSubst.compute(key) { key, oldValue -> if (oldValue == null || oldValue == newValue) newValue else TyInfer.TyVar(key) } } } - variants.first().subst[TyTypeParameter.self()]?.let { collapsed[TyTypeParameter.self()] = it } - return collapsed.toTypeSubst() + variants.first().subst[TyTypeParameter.self()]?.let { typeSubst[TyTypeParameter.self()] = it } + + val constGenerics = parentFn.constGenerics + val constSubst = mutableMapOf() + for (fn in variants) { + for ((key, newValue) in constGenerics.zip(fn.positionalConstArguments)) { + @Suppress("NAME_SHADOWING") + constSubst.compute(key) { key, oldValue -> + if (oldValue == null || oldValue == newValue) newValue else CtInferVar(key) + } + } + } + + // TODO: remap lifetimes + return Substitution(typeSubst = typeSubst, constSubst = constSubst) } private fun instantiatePath( @@ -390,22 +411,16 @@ class RsTypeInferenceWalker( pathExpr: RsPathExpr ): Ty { val path = pathExpr.path - val subst = instantiatePathGenerics(path, BoundElement(element, scopeEntry.subst)).subst if (element is RsGenericDeclaration) { inferConstArgumentTypes(element.constParameters, path.constArguments) } - val type = when (element) { - is RsPatBinding -> ctx.getBindingType(element) - is RsTypeDeclarationElement -> element.declaredType - is RsEnumVariant -> element.parentEnum.declaredType - is RsFunction -> element.type - is RsConstant -> element.typeReference?.type ?: TyUnknown - is RsConstParameter -> element.typeReference?.type ?: TyUnknown - is RsSelfParameter -> element.typeOfValue - else -> return TyUnknown - } + val subst = instantiatePathGenerics( + path, + BoundElement(element, scopeEntry.subst), + PathExprResolver.fromContext(ctx) + ).subst val typeParameters = when { scopeEntry is AssocItemScopeEntry && element is RsAbstractable -> { @@ -421,7 +436,10 @@ class RsTypeInferenceWalker( val typeParameters = ctx.instantiateBounds(owner.trait) // UFCS - add predicate `Self : Trait` val selfTy = subst[TyTypeParameter.self()] ?: ctx.typeVarForParam(TyTypeParameter.self()) - val newSubst = owner.trait.generics.associateBy { it }.toTypeSubst() + val newSubst = Substitution( + typeSubst = owner.trait.generics.associateBy { it }, + constSubst = owner.trait.constGenerics.associateBy { it } + ) val boundTrait = BoundElement(owner.trait, newSubst) .substitute(typeParameters) val traitRef = TraitRef(selfTy, boundTrait) @@ -449,6 +467,16 @@ class RsTypeInferenceWalker( unifySubst(subst, typeParameters) + val type = when (element) { + is RsPatBinding -> ctx.getBindingType(element) + is RsTypeDeclarationElement -> element.declaredType + is RsEnumVariant -> element.parentEnum.declaredType + is RsFunction -> element.type + is RsConstant -> element.typeReference?.type ?: TyUnknown + is RsConstParameter -> element.typeReference?.type ?: TyUnknown + is RsSelfParameter -> element.typeOfValue + else -> return TyUnknown + } val tupleFields = (element as? RsFieldsOwner)?.tupleFields return if (tupleFields != null) { // Treat tuple constructor as a function @@ -478,6 +506,13 @@ class RsTypeInferenceWalker( } } } + subst1.constSubst.forEach { (k, c1) -> + subst2[k]?.let { c2 -> + if (k != c1 && c1 !is CtConstParameter && c1 !is CtUnknown) { + ctx.combineConsts(c2, c1) + } + } + } // TODO take into account the lifetimes } @@ -565,7 +600,11 @@ class RsTypeInferenceWalker( return calleeType.retType } - private fun inferMethodCallExprType(receiver: Ty, methodCall: RsMethodCall, expected: Ty?): Ty { + private fun inferMethodCallExprType( + receiver: Ty, + methodCall: RsMethodCall, + expected: Ty? + ): Ty { val argExprs = methodCall.valueArgumentList.exprList val callee = run { val variants = resolveMethodCallReferenceWithReceiverType(lookup, receiver, methodCall) @@ -584,7 +623,7 @@ class RsTypeInferenceWalker( inferConstArgumentTypes(callee.element.constParameters, methodCall.constArguments) - var typeParameters = ctx.instantiateMethodOwnerSubstitution(callee, methodCall) + var newSubst = ctx.instantiateMethodOwnerSubstitution(callee, methodCall) // TODO: borrow adjustments for self parameter /* @@ -594,23 +633,25 @@ class RsTypeInferenceWalker( } */ - typeParameters = ctx.instantiateBounds(callee.element, callee.selfTy, typeParameters) + newSubst = ctx.instantiateBounds(callee.element, callee.selfTy, newSubst) - val fnSubst = run { - val typeArguments = methodCall.typeArgumentList?.typeReferenceList.orEmpty().map { it.type } - if (typeArguments.isEmpty()) { - emptySubstitution - } else { - val parameters = callee.element.typeParameterList?.typeParameterList.orEmpty() - .map { TyTypeParameter.named(it) } - parameters.zip(typeArguments).toMap().toTypeSubst() - } + val typeParameters = callee.element.typeParameters.map { TyTypeParameter.named(it) } + val typeArguments = methodCall.typeArguments.map { it.type } + val typeSubst = typeParameters.zip(typeArguments).toMap() + + val constParameters = callee.element.constParameters.map { CtConstParameter(it) } + val resolver = PathExprResolver.fromContext(ctx) + val constArguments = methodCall.constArguments.withIndex().map { (i, expr) -> + val expectedTy = constParameters.getOrNull(i)?.parameter?.typeReference?.type ?: TyUnknown + expr.evaluate(expectedTy, resolver) } + val constSubst = constParameters.zip(constArguments).toMap() - unifySubst(fnSubst, typeParameters) + val fnSubst = Substitution(typeSubst = typeSubst, constSubst = constSubst) + unifySubst(fnSubst, newSubst) val methodType = (callee.element.type) - .substitute(typeParameters) + .substitute(newSubst) .foldWith(associatedTypeNormalizer) as TyFunction if (expected != null && !callee.element.isAsync) ctx.combineTypes(expected, methodType.retType) // drop first element of paramTypes because it's `self` param @@ -779,7 +820,7 @@ class RsTypeInferenceWalker( private fun inferLabeledExprType(expr: RsLabeledExpression, baseType: Ty, matchOnlyByLabel: Boolean): Ty { val returningTypes = mutableListOf(baseType) - val label = expr.labelDecl?.name + val label = expr.takeIf { it.block?.stub == null }?.labelDecl?.name fun collectReturningTypes(element: PsiElement, matchOnlyByLabel: Boolean) { element.forEachChild { child -> @@ -1140,7 +1181,7 @@ class RsTypeInferenceWalker( private fun inferIncludeMacro(macroCall: RsMacroCall): Ty { return when (macroCall.macroName) { "include_str" -> TyReference(TyStr, Mutability.IMMUTABLE) - "include_bytes" -> TyReference(TyArray(TyInteger.U8, null), Mutability.IMMUTABLE) + "include_bytes" -> TyReference(TyArray(TyInteger.U8, CtUnknown), Mutability.IMMUTABLE) else -> TyUnknown } } @@ -1227,18 +1268,19 @@ class RsTypeInferenceWalker( ?: return TySlice(TyUnknown) val sizeExpr = expr.sizeExpr sizeExpr?.inferType(TyInteger.USize) - val size = if (sizeExpr != null) { - val exprValue = RsConstExprEvaluator.evaluate(sizeExpr, TyInteger.USize) { - ctx.getResolvedPath(it).singleOrNull()?.element - } - (exprValue as? ExprValue.Integer)?.value - } else { - null - } + val size = sizeExpr?.evaluate(TyInteger.USize, PathExprResolver.fromContext(ctx)) ?: CtUnknown elementType to size } else { val elementTypes = expr.arrayElements?.map { it.inferType(expectedElemTy) } - if (elementTypes.isNullOrEmpty()) return TyArray(TyInfer.TyVar(), 0) + val size = if (elementTypes != null) { + val size = elementTypes.size.toLong() + ConstExpr.Value.Integer(size, TyInteger.USize).toConst() + } else { + CtUnknown + } + if (elementTypes.isNullOrEmpty()) { + return TyArray(TyInfer.TyVar(), size.foldCtConstParameterWith { CtInferVar(it) }) + } // '!!' is safe here because we've just checked that elementTypes isn't null val elementType = getMoreCompleteType(elementTypes!!) @@ -1247,7 +1289,7 @@ class RsTypeInferenceWalker( } else { elementType } - inferredTy to elementTypes.size.toLong() + inferredTy to size } return TyArray(elementType, size) @@ -1328,6 +1370,24 @@ class RsTypeInferenceWalker( val selection = lookup.selectProjection(TraitRef(this, futureTrait.withSubst()), outputType) return selection.ok()?.register() ?: TyUnknown } + + companion object { + // ignoring possible false-positives (it's only basic experimental type checking) + + val IGNORED_TYS: List> = listOf( + TyUnknown::class.java, + TyInfer.TyVar::class.java, + TyTypeParameter::class.java, + TyProjection::class.java, + TyTraitObject::class.java, + TyAnon::class.java + ) + + val IGNORED_CONSTS: List> = listOf( + CtUnknown::class.java, + CtInferVar::class.java + ) + } } private val RsSelfParameter.typeOfValue: Ty diff --git a/src/main/kotlin/org/rust/lang/core/types/regions/ReEarlyBound.kt b/src/main/kotlin/org/rust/lang/core/types/regions/ReEarlyBound.kt index b791635a000..9acca8be422 100644 --- a/src/main/kotlin/org/rust/lang/core/types/regions/ReEarlyBound.kt +++ b/src/main/kotlin/org/rust/lang/core/types/regions/ReEarlyBound.kt @@ -5,17 +5,23 @@ package org.rust.lang.core.types.regions +import com.intellij.codeInsight.completion.CompletionUtil import org.rust.lang.core.psi.RsLifetimeParameter import org.rust.lang.core.types.HAS_RE_EARLY_BOUND_MASK import org.rust.lang.core.types.TypeFlags /** - * Region bound in a type or fn declaration which will be - * substituted 'early' -- that is, at the same time when type - * parameters are substituted. + * Region bound in a type or fn declaration, which will be substituted 'early' -- that is, + * at the same time when type parameters are substituted. */ -data class ReEarlyBound(val parameter: RsLifetimeParameter) : Region() { +class ReEarlyBound(parameter: RsLifetimeParameter) : Region() { + val parameter: RsLifetimeParameter = CompletionUtil.getOriginalOrSelf(parameter) + override val flags: TypeFlags = HAS_RE_EARLY_BOUND_MASK + override fun equals(other: Any?): Boolean = other is ReEarlyBound && other.parameter == parameter + + override fun hashCode(): Int = parameter.hashCode() + override fun toString(): String = parameter.name ?: "" } diff --git a/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt b/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt index 4a090dbd62d..879224cc49f 100644 --- a/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt +++ b/src/main/kotlin/org/rust/lang/core/types/ty/TyAdt.kt @@ -8,10 +8,14 @@ package org.rust.lang.core.types.ty import com.intellij.codeInsight.completion.CompletionUtil import org.rust.lang.core.psi.RsTypeAlias import org.rust.lang.core.psi.ext.RsStructOrEnumItemElement +import org.rust.lang.core.psi.ext.constParameters import org.rust.lang.core.psi.ext.lifetimeParameters import org.rust.lang.core.psi.ext.typeParameters import org.rust.lang.core.types.BoundElement import org.rust.lang.core.types.Substitution +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.consts.CtUnknown import org.rust.lang.core.types.infer.TypeFolder import org.rust.lang.core.types.infer.TypeVisitor import org.rust.lang.core.types.mergeFlags @@ -22,15 +26,16 @@ import org.rust.lang.core.types.regions.Region /** * Represents struct/enum/union. * "ADT" may be read as "Algebraic Data Type". - * The name is inspired by rustc + * The name is inspired by rustc. */ @Suppress("DataClassPrivateConstructor") data class TyAdt private constructor( val item: RsStructOrEnumItemElement, val typeArguments: List, val regionArguments: List, + val constArguments: List, val aliasedBy: BoundElement? -) : Ty(mergeFlags(typeArguments) or mergeFlags(regionArguments)) { +) : Ty(mergeFlags(typeArguments) or mergeFlags(regionArguments) or mergeFlags(constArguments)) { // This method is rarely called (in comparison with folding), so we can implement it in a such inefficient way. override val typeParameterValues: Substitution @@ -41,7 +46,10 @@ data class TyAdt private constructor( val regionSubst = item.lifetimeParameters.withIndex().associate { (i, param) -> ReEarlyBound(param) to regionArguments.getOrElse(i) { ReUnknown } } - return Substitution(typeSubst, regionSubst) + val constSubst = item.constParameters.withIndex().associate { (i, param) -> + CtConstParameter(param) to constArguments.getOrElse(i) { CtUnknown } + } + return Substitution(typeSubst, regionSubst, constSubst) } override fun superFoldWith(folder: TypeFolder): TyAdt = @@ -49,20 +57,27 @@ data class TyAdt private constructor( item, typeArguments.map { it.foldWith(folder) }, regionArguments.map { it.foldWith(folder) }, + constArguments.map { it.foldWith(folder) }, aliasedBy ) override fun superVisitWith(visitor: TypeVisitor): Boolean = - typeArguments.any { it.visitWith(visitor) } || regionArguments.any { it.visitWith(visitor) } + typeArguments.any { it.visitWith(visitor) } || + regionArguments.any { it.visitWith(visitor) } || + constArguments.any { it.visitWith(visitor) } fun withAlias(aliasedBy: BoundElement?): TyAdt = copy(aliasedBy = aliasedBy) companion object { - fun valueOf(struct: RsStructOrEnumItemElement): TyAdt { - val item = CompletionUtil.getOriginalOrSelf(struct) - return TyAdt(item, defaultTypeArguments(struct), defaultRegionArguments(struct), null) - } + fun valueOf(struct: RsStructOrEnumItemElement): TyAdt = + TyAdt( + CompletionUtil.getOriginalOrSelf(struct), + defaultTypeArguments(struct), + defaultRegionArguments(struct), + defaultConstArguments(struct), + null + ) } } @@ -71,3 +86,6 @@ private fun defaultTypeArguments(item: RsStructOrEnumItemElement): List = private fun defaultRegionArguments(item: RsStructOrEnumItemElement): List = item.lifetimeParameters.map { param -> ReEarlyBound(param) } + +private fun defaultConstArguments(item: RsStructOrEnumItemElement): List = + item.constParameters.map { param -> CtConstParameter(param) } diff --git a/src/main/kotlin/org/rust/lang/core/types/ty/TyArray.kt b/src/main/kotlin/org/rust/lang/core/types/ty/TyArray.kt index bdaae032dae..29edd059a53 100644 --- a/src/main/kotlin/org/rust/lang/core/types/ty/TyArray.kt +++ b/src/main/kotlin/org/rust/lang/core/types/ty/TyArray.kt @@ -5,13 +5,17 @@ package org.rust.lang.core.types.ty +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.asLong import org.rust.lang.core.types.infer.TypeFolder import org.rust.lang.core.types.infer.TypeVisitor -data class TyArray(val base: Ty, val size: Long?) : Ty(base.flags) { +data class TyArray(val base: Ty, val const: Const) : Ty(base.flags or const.flags) { + val size: Long? get() = const.asLong() + override fun superFoldWith(folder: TypeFolder): Ty = - TyArray(base.foldWith(folder), size) + TyArray(base.foldWith(folder), const.foldWith(folder)) override fun superVisitWith(visitor: TypeVisitor): Boolean = - base.visitWith(visitor) + base.visitWith(visitor) || const.visitWith(visitor) } diff --git a/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExpr.kt b/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExpr.kt new file mode 100644 index 00000000000..16e179fcadb --- /dev/null +++ b/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExpr.kt @@ -0,0 +1,92 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.utils.evaluation + +import org.rust.lang.core.psi.ext.BinaryOperator +import org.rust.lang.core.psi.ext.UnaryOperator +import org.rust.lang.core.types.TypeFlags +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtUnevaluated +import org.rust.lang.core.types.consts.CtUnknown +import org.rust.lang.core.types.consts.CtValue +import org.rust.lang.core.types.infer.TypeFoldable +import org.rust.lang.core.types.infer.TypeFolder +import org.rust.lang.core.types.infer.TypeVisitor +import org.rust.lang.core.types.ty.* + +fun ConstExpr<*>.toConst(): Const = + when (this) { + is ConstExpr.Constant -> const + is ConstExpr.Value -> CtValue(this) + is ConstExpr.Error -> CtUnknown + else -> CtUnevaluated(this) + } + +sealed class ConstExpr(val flags: TypeFlags = 0) : TypeFoldable> { + abstract val expectedTy: T? + + data class Unary( + val operator: UnaryOperator, + val expr: ConstExpr, + override val expectedTy: T + ) : ConstExpr(expr.flags) { + override fun superFoldWith(folder: TypeFolder): Unary = Unary(operator, expr.foldWith(folder), expectedTy) + override fun superVisitWith(visitor: TypeVisitor): Boolean = expr.visitWith(visitor) + } + + data class Binary( + val left: ConstExpr, + val operator: BinaryOperator, + val right: ConstExpr, + override val expectedTy: T + ) : ConstExpr(left.flags or right.flags) { + override fun superFoldWith(folder: TypeFolder): Binary = + Binary(left.foldWith(folder), operator, right.foldWith(folder), expectedTy) + + override fun superVisitWith(visitor: TypeVisitor): Boolean = left.visitWith(visitor) || right.visitWith(visitor) + } + + data class Constant( + val const: Const, + override val expectedTy: T + ) : ConstExpr(const.flags) { + override fun superFoldWith(folder: TypeFolder): Constant = Constant(const.foldWith(folder), expectedTy) + override fun superVisitWith(visitor: TypeVisitor): Boolean = const.visitWith(visitor) + } + + sealed class Value : ConstExpr() { + override fun superFoldWith(folder: TypeFolder): Value = this + override fun superVisitWith(visitor: TypeVisitor): Boolean = false + + data class Bool(val value: Boolean) : Value() { + override val expectedTy: TyBool = TyBool + override fun toString(): String = value.toString() + } + + data class Integer(val value: Long, override val expectedTy: TyInteger) : Value() { + override fun toString(): String = value.toString() + } + + data class Float(val value: Double, override val expectedTy: TyFloat) : Value() { + override fun toString(): String = value.toString() + } + + data class Char(val value: String) : Value() { + override val expectedTy: TyChar = TyChar + override fun toString(): String = value + } + + data class Str(val value: String, override val expectedTy: TyReference) : Value() { + override fun toString(): String = value + } + } + + class Error : ConstExpr() { + override val expectedTy: T? = null + override fun superFoldWith(folder: TypeFolder): ConstExpr = this + override fun superVisitWith(visitor: TypeVisitor): Boolean = false + } +} diff --git a/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprBuilder.kt b/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprBuilder.kt new file mode 100644 index 00000000000..f8c56889ae3 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprBuilder.kt @@ -0,0 +1,169 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.utils.evaluation + +import org.rust.lang.core.psi.* +import org.rust.lang.core.psi.ext.* +import org.rust.lang.core.types.consts.CtConstParameter +import org.rust.lang.core.types.ty.* +import org.rust.lang.core.types.type + +fun RsExpr.toConstExpr( + expectedTy: Ty = type, + resolver: PathExprResolver? = PathExprResolver.default +): ConstExpr? { + val builder = when (expectedTy) { + is TyInteger -> IntegerConstExprBuilder(expectedTy, resolver) + is TyBool -> BoolConstExprBuilder(resolver) + is TyFloat -> FloatConstExprBuilder(expectedTy, resolver) + is TyChar -> CharConstExprBuilder(resolver) + // TODO: type should be "wider" + STR_REF_TYPE -> StrConstExprBuilder(resolver) + else -> null + } + return builder?.build(this) +} + +private val STR_REF_TYPE: TyReference = TyReference(TyStr, Mutability.IMMUTABLE) + +private abstract class ConstExprBuilder { + protected abstract val expectedTy: T + protected abstract val resolver: PathExprResolver? + protected abstract val RsLitExpr.value: V? + protected abstract fun V.wrap(): ConstExpr + + protected fun makeLeafValue(expr: RsLitExpr): ConstExpr? = expr.value?.wrap() + + protected fun makeLeafParameter(parameter: RsConstParameter): ConstExpr = + ConstExpr.Constant(CtConstParameter(parameter), expectedTy) + + fun build(expr: RsExpr?): ConstExpr? = build(expr, 0) + + protected fun build(expr: RsExpr?, depth: Int): ConstExpr? { + // To prevent SO we restrict max depth of expression + if (depth >= MAX_EXPR_DEPTH) return null + return buildInner(expr, depth) + } + + protected open fun buildInner(expr: RsExpr?, depth: Int): ConstExpr? { + return when (expr) { + is RsLitExpr -> makeLeafValue(expr) + is RsParenExpr -> build(expr.expr, depth + 1) + is RsBlockExpr -> build(expr.block.expr, depth + 1) + is RsPathExpr -> { + val element = resolver?.invoke(expr) + + val typeReference = when (element) { + is RsConstant -> element.typeReference?.takeIf { element.isConst } + is RsConstParameter -> element.typeReference + else -> null + } + + val typeElementPath = (typeReference?.typeElement as? RsBaseType)?.path ?: return null + if (TyPrimitive.fromPath(typeElementPath) != expectedTy) return null + + when (element) { + is RsConstant -> build(element.expr, depth + 1) + is RsConstParameter -> makeLeafParameter(element) + else -> null + } + } + else -> null + } + } + + companion object { + private const val MAX_EXPR_DEPTH: Int = 64 + } +} + +private class IntegerConstExprBuilder( + override val expectedTy: TyInteger, + override val resolver: PathExprResolver? +) : ConstExprBuilder() { + override val RsLitExpr.value: Long? get() = integerValue + override fun Long.wrap(): ConstExpr.Value.Integer = ConstExpr.Value.Integer(this, expectedTy) + + override fun buildInner(expr: RsExpr?, depth: Int): ConstExpr? { + return when (expr) { + is RsUnaryExpr -> { + if (expr.operatorType != UnaryOperator.MINUS) return null + val value = build(expr.expr, depth + 1) ?: return null + ConstExpr.Unary(UnaryOperator.MINUS, value, expectedTy) + } + is RsBinaryExpr -> { + val op = expr.operatorType as? ArithmeticOp ?: return null + val lhs = build(expr.left, depth + 1) ?: return null + val rhs = build(expr.right, depth + 1) ?: return null + ConstExpr.Binary(lhs, op, rhs, expectedTy) + } + else -> super.buildInner(expr, depth) + } + } +} + +private class BoolConstExprBuilder( + override val resolver: PathExprResolver? +) : ConstExprBuilder() { + override val expectedTy: TyBool = TyBool + override val RsLitExpr.value: Boolean? get() = booleanValue + override fun Boolean.wrap(): ConstExpr.Value.Bool = ConstExpr.Value.Bool(this) + + override fun buildInner(expr: RsExpr?, depth: Int): ConstExpr? { + return when (expr) { + is RsBinaryExpr -> when (expr.operatorType) { + LogicOp.AND -> { + val lhs = build(expr.left, depth + 1) ?: return null + val rhs = build(expr.right, depth + 1) ?: ConstExpr.Error() + ConstExpr.Binary(lhs, LogicOp.AND, rhs, expectedTy) + } + LogicOp.OR -> { + val lhs = build(expr.left, depth + 1) ?: return null + val rhs = build(expr.right, depth + 1) ?: ConstExpr.Error() + ConstExpr.Binary(lhs, LogicOp.OR, rhs, expectedTy) + } + ArithmeticOp.BIT_XOR -> { + val lhs = build(expr.left, depth + 1) ?: return null + val rhs = build(expr.right, depth + 1) ?: return null + ConstExpr.Binary(lhs, ArithmeticOp.BIT_XOR, rhs, expectedTy) + } + else -> null + } + is RsUnaryExpr -> when (expr.operatorType) { + UnaryOperator.NOT -> { + val value = build(expr.expr, depth + 1) ?: return null + ConstExpr.Unary(UnaryOperator.NOT, value, expectedTy) + } + else -> null + } + else -> super.buildInner(expr, depth) + } + } +} + +private class FloatConstExprBuilder( + override val expectedTy: TyFloat, + override val resolver: PathExprResolver? +) : ConstExprBuilder() { + override val RsLitExpr.value: Double? get() = floatValue + override fun Double.wrap(): ConstExpr.Value.Float = ConstExpr.Value.Float(this, expectedTy) +} + +private class CharConstExprBuilder( + override val resolver: PathExprResolver? +) : ConstExprBuilder() { + override val expectedTy: TyChar = TyChar + override val RsLitExpr.value: String? get() = charValue + override fun String.wrap(): ConstExpr.Value.Char = ConstExpr.Value.Char(this) +} + +private class StrConstExprBuilder( + override val resolver: PathExprResolver? +) : ConstExprBuilder() { + override val expectedTy: TyReference = STR_REF_TYPE + override val RsLitExpr.value: String? get() = stringValue + override fun String.wrap(): ConstExpr.Value.Str = ConstExpr.Value.Str(this, expectedTy) +} diff --git a/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprEvaluator.kt b/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprEvaluator.kt new file mode 100644 index 00000000000..2aae55f13bf --- /dev/null +++ b/src/main/kotlin/org/rust/lang/utils/evaluation/ConstExprEvaluator.kt @@ -0,0 +1,185 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.utils.evaluation + +import org.rust.lang.core.psi.RsExpr +import org.rust.lang.core.psi.ext.ArithmeticOp +import org.rust.lang.core.psi.ext.LogicOp +import org.rust.lang.core.psi.ext.UnaryOperator +import org.rust.lang.core.types.consts.Const +import org.rust.lang.core.types.consts.CtUnevaluated +import org.rust.lang.core.types.consts.CtUnknown +import org.rust.lang.core.types.consts.CtValue +import org.rust.lang.core.types.infer.TypeFoldable +import org.rust.lang.core.types.infer.TypeFolder +import org.rust.lang.core.types.infer.needsEval +import org.rust.lang.core.types.ty.Ty +import org.rust.lang.core.types.ty.TyBool +import org.rust.lang.core.types.ty.TyInteger +import org.rust.lang.core.types.type + +fun RsExpr.evaluate( + expectedTy: Ty = type, + resolver: PathExprResolver? = PathExprResolver.default +): Const = toConstExpr(expectedTy, resolver)?.evaluate()?.toConst() ?: CtUnknown + +private fun ConstExpr.evaluate(): ConstExpr = + when (expectedTy) { + is TyBool -> simplifyToBool(this) + is TyInteger -> simplifyToInteger(this) + else -> this + } + +fun TypeFoldable.tryEvaluate(): T = + foldWith(object : TypeFolder { + override fun foldTy(ty: Ty): Ty = + if (ty.needsEval) ty.superFoldWith(this) else ty + + override fun foldConst(const: Const): Const = + if (const is CtUnevaluated && const.needsEval) { + const.expr.evaluate().toConst() + } else { + const + } + }) + +private fun simplifyToBool(expr: ConstExpr): ConstExpr { + val value = when (expr) { + is ConstExpr.Constant -> { + val const = expr.const as? CtValue + val value = const?.expr as? ConstExpr.Value.Bool + value?.value ?: return expr + } + is ConstExpr.Unary -> { + if (expr.operator != UnaryOperator.NOT) return ConstExpr.Error() + when (val result = simplifyToBool(expr.expr)) { + is ConstExpr.Value.Bool -> !result.value + is ConstExpr.Error -> return ConstExpr.Error() + else -> return expr.copy(expr = result) + } + } + is ConstExpr.Binary -> { + val left = simplifyToBool(expr.left) + val right = simplifyToBool(expr.right) + when (expr.operator) { + LogicOp.AND -> + when { + left is ConstExpr.Value.Bool && !left.value -> + false // false && _ --> false + right is ConstExpr.Value.Bool && !right.value -> + false // _ && false --> false + left is ConstExpr.Value.Bool && right is ConstExpr.Value.Bool -> + left.value && right.value + left is ConstExpr.Error || right is ConstExpr.Error -> + return ConstExpr.Error() + else -> + return expr.copy(left = left, right = right) + } + LogicOp.OR -> + when { + left is ConstExpr.Value.Bool && left.value -> + true // true || _ --> true + right is ConstExpr.Value.Bool && right.value -> + true // _ || true --> true + left is ConstExpr.Value.Bool && right is ConstExpr.Value.Bool -> + left.value || right.value + left is ConstExpr.Error || right is ConstExpr.Error -> + return ConstExpr.Error() + else -> + return expr.copy(left = left, right = right) + } + ArithmeticOp.BIT_XOR -> + when { + left is ConstExpr.Value.Bool && right is ConstExpr.Value.Bool -> + left.value xor right.value + left is ConstExpr.Error || right is ConstExpr.Error -> + return ConstExpr.Error() + else -> + return expr.copy(left = left, right = right) + } + else -> + return ConstExpr.Error() + } + } + else -> return expr + } + @Suppress("UNCHECKED_CAST") + return ConstExpr.Value.Bool(value) as ConstExpr +} + +private fun simplifyToInteger(expr: ConstExpr): ConstExpr { + val expectedTy = expr.expectedTy + if (expectedTy !is TyInteger) return ConstExpr.Error() + + val value = when (expr) { + is ConstExpr.Constant -> { + val const = expr.const as? CtValue + val value = const?.expr as? ConstExpr.Value.Integer + value?.value ?: return expr + } + is ConstExpr.Unary -> { + if (expr.operator != UnaryOperator.MINUS) return ConstExpr.Error() + when (val result = simplifyToInteger(expr.expr)) { + is ConstExpr.Value.Integer -> -result.value + is ConstExpr.Error -> return ConstExpr.Error() + else -> return expr.copy(expr = result) + } + } + is ConstExpr.Binary -> { + val left = simplifyToInteger(expr.left) + val right = simplifyToInteger(expr.right) + when { + left is ConstExpr.Value.Integer && right is ConstExpr.Value.Integer -> + // TODO: check overflow + when (expr.operator) { + ArithmeticOp.ADD -> left.value + right.value + ArithmeticOp.SUB -> left.value - right.value + ArithmeticOp.MUL -> left.value * right.value + ArithmeticOp.DIV -> if (right.value == 0L) null else left.value / right.value + ArithmeticOp.REM -> if (right.value == 0L) null else left.value % right.value + ArithmeticOp.BIT_AND -> left.value and right.value + ArithmeticOp.BIT_OR -> left.value or right.value + ArithmeticOp.BIT_XOR -> left.value xor right.value + // We can't simply convert `right.value` to Int because after conversion of quite large Long values + // (> 2^31 - 1) we can get any Int value including negative one, so it can lead to incorrect result. + // But if `rightValue` >= `java.lang.Long.BYTES` we know result without computation: + // overflow in 'shl' case and 0 in 'shr' case. + ArithmeticOp.SHL -> if (right.value >= java.lang.Long.BYTES) null else left.value shl right.value.toInt() + ArithmeticOp.SHR -> if (right.value >= java.lang.Long.BYTES) 0 else left.value shr right.value.toInt() + else -> return ConstExpr.Error() + } + left is ConstExpr.Error || right is ConstExpr.Error -> + return ConstExpr.Error() + else -> + return expr.copy(left = left, right = right) + } + + } + else -> return expr + } + val checkedValue = value?.validValueOrNull(expectedTy) ?: return ConstExpr.Error() + @Suppress("UNCHECKED_CAST") + return ConstExpr.Value.Integer(checkedValue, expectedTy) as ConstExpr +} + +private fun Long.validValueOrNull(ty: TyInteger): Long? = takeIf { it in ty.validValuesRange } + +// It returns wrong values for large types like `i128` or `usize`, but looks like it's enough for real cases +private val TyInteger.validValuesRange: LongRange + get() = when (this) { + TyInteger.U8 -> LongRange(0, 1L shl 8) + TyInteger.U16 -> LongRange(0, 1L shl 16) + TyInteger.U32 -> LongRange(0, 1L shl 32) + TyInteger.U64 -> LongRange(0, Long.MAX_VALUE) + TyInteger.U128 -> LongRange(0, Long.MAX_VALUE) + TyInteger.USize -> LongRange(0, Long.MAX_VALUE) + TyInteger.I8 -> LongRange(-(1L shl 7), (1L shl 7) - 1) + TyInteger.I16 -> LongRange(-(1L shl 15), (1L shl 15) - 1) + TyInteger.I32 -> LongRange(-(1L shl 31), (1L shl 31) - 1) + TyInteger.I64 -> LongRange(Long.MIN_VALUE, Long.MAX_VALUE) + TyInteger.I128 -> LongRange(Long.MIN_VALUE, Long.MAX_VALUE) + TyInteger.ISize -> LongRange(Long.MIN_VALUE, Long.MAX_VALUE) + } diff --git a/src/main/kotlin/org/rust/lang/utils/evaluation/ExprValue.kt b/src/main/kotlin/org/rust/lang/utils/evaluation/ExprValue.kt deleted file mode 100644 index 2573c8b9e89..00000000000 --- a/src/main/kotlin/org/rust/lang/utils/evaluation/ExprValue.kt +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Use of this source code is governed by the MIT license that can be - * found in the LICENSE file. - */ - -package org.rust.lang.utils.evaluation - -sealed class ExprValue { - data class Bool(val value: Boolean) : ExprValue() { - override fun toString(): String = value.toString() - } - - data class Integer(val value: Long) : ExprValue() { - override fun toString(): String = value.toString() - } - - data class Float(val value: Double) : ExprValue() { - override fun toString(): String = value.toString() - } - - data class Str(val value: String) : ExprValue() { - override fun toString(): String = value - } - - data class Char(val value: String) : ExprValue() { - override fun toString(): String = value - } -} diff --git a/src/main/kotlin/org/rust/lang/utils/evaluation/RsConstExprEvaluator.kt b/src/main/kotlin/org/rust/lang/utils/evaluation/RsConstExprEvaluator.kt deleted file mode 100644 index 1ad7274a838..00000000000 --- a/src/main/kotlin/org/rust/lang/utils/evaluation/RsConstExprEvaluator.kt +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Use of this source code is governed by the MIT license that can be - * found in the LICENSE file. - */ - -package org.rust.lang.utils.evaluation - -import org.rust.lang.core.psi.* -import org.rust.lang.core.psi.ext.* -import org.rust.lang.core.types.ty.* -import org.rust.lang.core.types.type - -private val STR_REF_TYPE: TyReference = TyReference(TyStr, Mutability.IMMUTABLE) - -object RsConstExprEvaluator { - - private val defaultExprPathResolver: (RsPathExpr) -> RsElement? = { it.path.reference.resolve() } - - fun evaluate( - expr: RsExpr, - expectedTy: Ty = expr.type, - pathExprResolver: ((RsPathExpr) -> RsElement?)? = defaultExprPathResolver - ): ExprValue? { - val evaluation = when (expectedTy) { - is TyInteger -> IntegerExprEvaluation(expectedTy, pathExprResolver) - is TyBool -> BoolExprEvaluation(pathExprResolver) - is TyFloat -> FloatExprEvaluation(expectedTy, pathExprResolver) - is TyChar -> CharExprEvaluation(pathExprResolver) - // TODO: type should be "wider" - STR_REF_TYPE -> StrExprEvaluation(pathExprResolver) - else -> null - } - return evaluation?.evaluate(expr) - } -} - -private open class ExprEvaluation( - protected val expectedTy: T, - private val pathExprResolver: ((RsPathExpr) -> RsElement?)?, - private val evalLitExpr: RsLitExpr.() -> V?, - private val exprValueCtr: (V) -> ExprValue -) { - - fun evaluate(expr: RsExpr?): ExprValue? = evaluate(expr, 0)?.let(exprValueCtr) - - protected fun evaluate(expr: RsExpr?, depth: Int): V? { - // To prevent SO we restrict max depth of expression - if (depth >= MAX_EXPR_DEPTH) return null - return evaluateInner(expr, depth) - } - - protected open fun evaluateInner(expr: RsExpr?, depth: Int): V? { - return when (expr) { - is RsLitExpr -> expr.evalLitExpr() - is RsParenExpr -> evaluate(expr.expr, depth + 1) - is RsPathExpr -> { - val const = pathExprResolver?.invoke(expr) as? RsConstant ?: return null - if (!const.isConst) return null - val path = (const.typeReference?.typeElement as? RsBaseType)?.path ?: return null - if (TyPrimitive.fromPath(path) != expectedTy) return null - evaluate(const.expr, depth + 1) - } - else -> null - } - } - - companion object { - private const val MAX_EXPR_DEPTH: Int = 64 - } -} - -private class IntegerExprEvaluation( - expectedTy: TyInteger, - pathExprResolver: ((RsPathExpr) -> RsElement?)? -) : ExprEvaluation(expectedTy, pathExprResolver, RsLitExpr::integerValue, ExprValue::Integer) { - - override fun evaluateInner(expr: RsExpr?, depth: Int): Long? { - return when (expr) { - is RsUnaryExpr -> { - if (expr.operatorType != UnaryOperator.MINUS) return null - val value = evaluate(expr.expr, depth + 1) ?: return null - (-value).validValueOrNull(expectedTy) - } - is RsBinaryExpr -> { - val op = expr.operatorType as? ArithmeticOp ?: return null - val leftValue = evaluate(expr.left, depth + 1) ?: return null - val rightValue = evaluate(expr.right, depth + 1) ?: return null - // TODO: check overflow - val result = when (op) { - ArithmeticOp.ADD -> leftValue + rightValue - ArithmeticOp.SUB -> leftValue - rightValue - ArithmeticOp.MUL -> leftValue * rightValue - ArithmeticOp.DIV -> if (rightValue == 0L) null else leftValue / rightValue - ArithmeticOp.REM -> if (rightValue == 0L) null else leftValue % rightValue - ArithmeticOp.BIT_AND -> leftValue and rightValue - ArithmeticOp.BIT_OR -> leftValue or rightValue - ArithmeticOp.BIT_XOR -> leftValue xor rightValue - // We can't simply convert `rightValue` to Int - // because after conversion of quite large Long values (> 2^31 - 1) - // we can get any Int value including negative one - // so it can lead to incorrect result. - // But if `rightValue` >= `java.lang.Long.BYTES` - // we know result without computation: - // overflow in 'shl' case and 0 in 'shr' case. - ArithmeticOp.SHL -> if (rightValue >= java.lang.Long.BYTES) null else leftValue shl rightValue.toInt() - ArithmeticOp.SHR -> if (rightValue >= java.lang.Long.BYTES) 0 else leftValue shr rightValue.toInt() - } - result?.validValueOrNull(expectedTy) - } - else -> super.evaluateInner(expr, depth) - } - } - - // It returns wrong values for large types like `i128` or `usize` - // But looks like like it's enough for real cases - private val TyInteger.validValuesRange: LongRange - get() = when (this) { - TyInteger.U8 -> LongRange(0, 1L shl 8) - TyInteger.U16 -> LongRange(0, 1L shl 16) - TyInteger.U32 -> LongRange(0, 1L shl 32) - TyInteger.U64 -> LongRange(0, Long.MAX_VALUE) - TyInteger.U128 -> LongRange(0, Long.MAX_VALUE) - TyInteger.USize -> LongRange(0, Long.MAX_VALUE) - TyInteger.I8 -> LongRange(-(1L shl 7), (1L shl 7) - 1) - TyInteger.I16 -> LongRange(-(1L shl 15), (1L shl 15) - 1) - TyInteger.I32 -> LongRange(-(1L shl 31), (1L shl 31) - 1) - TyInteger.I64 -> LongRange(Long.MIN_VALUE, Long.MAX_VALUE) - TyInteger.I128 -> LongRange(Long.MIN_VALUE, Long.MAX_VALUE) - TyInteger.ISize -> LongRange(Long.MIN_VALUE, Long.MAX_VALUE) - } - - private fun Long.validValueOrNull(ty: TyInteger): Long? = if (this in ty.validValuesRange) this else null -} - -private class BoolExprEvaluation( - pathExprResolver: ((RsPathExpr) -> RsElement?)? -) : ExprEvaluation(TyBool, pathExprResolver, RsLitExpr::booleanValue, ExprValue::Bool) { - - override fun evaluateInner(expr: RsExpr?, depth: Int): Boolean? { - return when (expr) { - is RsBinaryExpr -> when (expr.operatorType) { - LogicOp.AND -> { - val lhs = evaluate(expr.left, depth + 1) ?: return null - if (!lhs) return false // false && _ --> false - val rhs = evaluate(expr.right, depth + 1) ?: return null - lhs && rhs - } - LogicOp.OR -> { - val lhs = evaluate(expr.left, depth + 1) ?: return null - if (lhs) return true // true || _ --> true - val rhs = evaluate(expr.right, depth + 1) ?: return null - lhs || rhs - } - ArithmeticOp.BIT_XOR -> { - val lhs = evaluate(expr.left, depth + 1) ?: return null - val rhs = evaluate(expr.right, depth + 1) ?: return null - lhs xor rhs - } - else -> null - } - is RsUnaryExpr -> when (expr.operatorType) { - UnaryOperator.NOT -> evaluate(expr.expr, depth + 1)?.let { !it } - else -> null - } - else -> super.evaluateInner(expr, depth) - } - } -} - -private class FloatExprEvaluation( - expectedTy: TyFloat, - pathExprResolver: ((RsPathExpr) -> RsElement?)? -) : ExprEvaluation(expectedTy, pathExprResolver, RsLitExpr::floatValue, ExprValue::Float) - -private class CharExprEvaluation( - pathExprResolver: ((RsPathExpr) -> RsElement?)? -) : ExprEvaluation(TyChar, pathExprResolver, RsLitExpr::charValue, ExprValue::Char) - -private class StrExprEvaluation( - pathExprResolver: ((RsPathExpr) -> RsElement?)? -) : ExprEvaluation(STR_REF_TYPE, pathExprResolver, RsLitExpr::stringValue, ExprValue::Str) diff --git a/src/main/kotlin/org/rust/lang/utils/evaluation/Utils.kt b/src/main/kotlin/org/rust/lang/utils/evaluation/Utils.kt new file mode 100644 index 00000000000..f585a84a201 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/utils/evaluation/Utils.kt @@ -0,0 +1,19 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.utils.evaluation + +import org.rust.lang.core.psi.RsPathExpr +import org.rust.lang.core.psi.ext.RsElement +import org.rust.lang.core.types.infer.RsInferenceContext + +class PathExprResolver(resolver: (RsPathExpr) -> RsElement?): (RsPathExpr) -> RsElement? by resolver { + companion object { + val default: PathExprResolver = PathExprResolver { it.path.reference.resolve() } + + fun fromContext(ctx: RsInferenceContext): PathExprResolver = + PathExprResolver { ctx.getResolvedPath(it).singleOrNull()?.element } + } +} From 7c168be8ab47c602d355fb69f6e06cb5487ead65 Mon Sep 17 00:00:00 2001 From: Mikhail Chernyavsky Date: Wed, 25 Dec 2019 18:07:33 +0300 Subject: [PATCH 3/3] T: Initial type inference for const generics --- .../SpecifyTypeExplicitlyIntentionTest.kt | 16 +- .../ImplementMembersHandlerTest.kt | 98 ++++++++++ .../core/completion/RsLookupElementTest.kt | 17 ++ .../lang/core/resolve/RsResolveCacheTest.kt | 12 ++ .../core/resolve/RsStubOnlyResolveTest.kt | 132 +++++++++++++ .../core/type/RsStubOnlyTypeInferenceTest.kt | 185 ++++++++++++++++++ 6 files changed, 457 insertions(+), 3 deletions(-) diff --git a/src/test/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntentionTest.kt b/src/test/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntentionTest.kt index 2406ab47ce9..a73a5ed07c0 100644 --- a/src/test/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntentionTest.kt +++ b/src/test/kotlin/org/rust/ide/intentions/SpecifyTypeExplicitlyIntentionTest.kt @@ -15,9 +15,15 @@ class SpecifyTypeExplicitlyIntentionTest : RsIntentionTestBase(SpecifyTypeExplic ) fun `test generic type`() = doAvailableTest( - """struct A(T);fn main() { let var/*caret*/ = A(42); } """, - """struct A(T);fn main() { let var: A = A(42); } """ + """struct A(T); fn main() { let var/*caret*/ = A(42); } """, + """struct A(T); fn main() { let var: A = A(42); } """ ) + + fun `test type with const generic`() = doAvailableTest( + """struct A; fn main() { let var/*caret*/ = A::<1>; } """, + """struct A; fn main() { let var: A<1> = A::<1>; } """ + ) + fun `test complex pattern`() = doAvailableTest( """ fn main() { let (a, b)/*caret*/ = (1, 2); } """, """ fn main() { let (a, b): (i32, i32) = (1, 2); } """ @@ -52,7 +58,11 @@ class SpecifyTypeExplicitlyIntentionTest : RsIntentionTestBase(SpecifyTypeExplic ) fun `test generic type with not inferred type`() = doUnavailableTest( - """struct A(T);fn main() { let var/*caret*/ = A(a); } """ + """struct A(T); fn main() { let var/*caret*/ = A(a); } """ + ) + + fun `test generic type with not inferred const generic`() = doUnavailableTest( + """struct A; fn main() { let var/*caret*/ = A; } """ ) fun `test anon type`() = doUnavailableTest(""" diff --git a/src/test/kotlin/org/rust/ide/refactoring/implementMembers/ImplementMembersHandlerTest.kt b/src/test/kotlin/org/rust/ide/refactoring/implementMembers/ImplementMembersHandlerTest.kt index 7f9c9022dba..55d73cf68fd 100644 --- a/src/test/kotlin/org/rust/ide/refactoring/implementMembers/ImplementMembersHandlerTest.kt +++ b/src/test/kotlin/org/rust/ide/refactoring/implementMembers/ImplementMembersHandlerTest.kt @@ -561,6 +561,104 @@ class ImplementMembersHandlerTest : RsTestBase() { } """) + fun `test implement generic trait with consts 1`() = doTest(""" + struct S; + trait T { + fn f1(_: S<{ M }>) -> S<{ M }>; + const C1: S<{ M }>; + fn f2(_: S<{ UNKNOWN }>) -> S<{ UNKNOWN }>; + const C2: S<{ UNKNOWN }>; + fn f3(_: [i32; M]) -> [i32; M]; + const C3: [i32; M]; + fn f4(_: [i32; UNKNOWN]) -> [i32; UNKNOWN]; + const C4: [i32; UNKNOWN]; + } + impl T<1> for S<1> {/*caret*/} + """, listOf( + ImplementMemberSelection("f1(_: S<{ M }>) -> S<{ M }>", true), + ImplementMemberSelection("C1: S<{ M }>", true), + ImplementMemberSelection("f2(_: S<{ UNKNOWN }>) -> S<{ UNKNOWN }>", true), + ImplementMemberSelection("C2: S<{ UNKNOWN }>", true), + ImplementMemberSelection("f3(_: [i32; M]) -> [i32; M]", true), + ImplementMemberSelection("C3: [i32; M]", true), + ImplementMemberSelection("f4(_: [i32; UNKNOWN]) -> [i32; UNKNOWN]", true), + ImplementMemberSelection("C4: [i32; UNKNOWN]", true) + ), """ + struct S; + trait T { + fn f1(_: S<{ M }>) -> S<{ M }>; + const C1: S<{ M }>; + fn f2(_: S<{ UNKNOWN }>) -> S<{ UNKNOWN }>; + const C2: S<{ UNKNOWN }>; + fn f3(_: [i32; M]) -> [i32; M]; + const C3: [i32; M]; + fn f4(_: [i32; UNKNOWN]) -> [i32; UNKNOWN]; + const C4: [i32; UNKNOWN]; + } + impl T<1> for S<1> { + fn f1(_: S<1>) -> S<1> { + unimplemented!() + } + + const C1: S<1> = S; + + fn f2(_: S<{}>) -> S<{}> { + unimplemented!() + } + + const C2: S<{}> = S; + + fn f3(_: [i32; 1]) -> [i32; 1] { + unimplemented!() + } + + const C3: [i32; 1] = []; + + fn f4(_: [i32; {}]) -> [i32; {}] { + unimplemented!() + } + + const C4: [i32; {}] = []; + } + """) + + fun `test implement generic trait with consts 2`() = doTest(""" + struct S; + trait T { + fn f1(_: S<{ M }>) -> S<{ M }>; + const C1: S<{ M }>; + fn f2(_: [i32; M]) -> [i32; M]; + const C2: [i32; M]; + } + impl T<{ K }> for S<{ K }> {/*caret*/} + """, listOf( + ImplementMemberSelection("f1(_: S<{ M }>) -> S<{ M }>", true), + ImplementMemberSelection("C1: S<{ M }>", true), + ImplementMemberSelection("f2(_: [i32; M]) -> [i32; M]", true), + ImplementMemberSelection("C2: [i32; M]", true) + ), """ + struct S; + trait T { + fn f1(_: S<{ M }>) -> S<{ M }>; + const C1: S<{ M }>; + fn f2(_: [i32; M]) -> [i32; M]; + const C2: [i32; M]; + } + impl T<{ K }> for S<{ K }> { + fn f1(_: S<{ K }>) -> S<{ K }> { + unimplemented!() + } + + const C1: S<{ K }> = S; + + fn f2(_: [i32; K]) -> [i32; K] { + unimplemented!() + } + + const C2: [i32; K] = []; + } + """) + fun `test do not implement methods already present`() = doTest(""" trait T { fn f1(); diff --git a/src/test/kotlin/org/rust/lang/core/completion/RsLookupElementTest.kt b/src/test/kotlin/org/rust/lang/core/completion/RsLookupElementTest.kt index dfcdbe54cda..b64e56993bf 100644 --- a/src/test/kotlin/org/rust/lang/core/completion/RsLookupElementTest.kt +++ b/src/test/kotlin/org/rust/lang/core/completion/RsLookupElementTest.kt @@ -279,6 +279,23 @@ class RsLookupElementTest : RsTestBase() { } """, tailText = "(x: i32) of T", typeText = "i32") + fun `test const generic function`() = checkProvider(""" + struct S(i32); + + trait T { + fn foo(); + } + + impl T<{ K }> for S<{ K }> { + fn foo() {} + } + + fn main() { + S::<1>::foo; + //^ + } + """, tailText = "() of T<1>", typeText = "()") + private fun check( @Language("Rust") code: String, tailText: String? = null, diff --git a/src/test/kotlin/org/rust/lang/core/resolve/RsResolveCacheTest.kt b/src/test/kotlin/org/rust/lang/core/resolve/RsResolveCacheTest.kt index 5a7b9eb1fd9..93b5222f781 100644 --- a/src/test/kotlin/org/rust/lang/core/resolve/RsResolveCacheTest.kt +++ b/src/test/kotlin/org/rust/lang/core/resolve/RsResolveCacheTest.kt @@ -207,6 +207,18 @@ class RsResolveCacheTest : RsTestBase() { //^ """, "\b2") + fun `test edit const expr`() = checkResolvedToXY(""" + struct S; + fn foo< + const N1: usize, + //X + const N2: usize + //Y + >() { + let _: S<{ N1/*caret*/ }>; + } //^ + """, "\b2") + private fun checkResolvedToXY(@Language("Rust") code: String, textToType: String) { InlineFile(code).withCaret() diff --git a/src/test/kotlin/org/rust/lang/core/resolve/RsStubOnlyResolveTest.kt b/src/test/kotlin/org/rust/lang/core/resolve/RsStubOnlyResolveTest.kt index f6aa4faf1b9..a7af429b99e 100644 --- a/src/test/kotlin/org/rust/lang/core/resolve/RsStubOnlyResolveTest.kt +++ b/src/test/kotlin/org/rust/lang/core/resolve/RsStubOnlyResolveTest.kt @@ -581,4 +581,136 @@ class RsStubOnlyResolveTest : RsResolveTestBase() { } } """) + + fun `test trait impl for const generic 1`() = stubOnlyResolve(""" + //- main.rs + #![feature(const_generics)] + mod bar; + use bar::T; + fn main() { + let s = [0]; + s.foo() + //^ bar.rs + } + + //- bar.rs + pub trait T { + fn foo(&self); + } + + impl T for [i32; 1] { + fn foo(&self) {} + } + """) + + fun `test trait impl for const generic 2`() = stubOnlyResolve(""" + //- main.rs + #![feature(const_generics)] + mod bar; + use bar::T; + fn main() { + let s = [0, 1]; + s.foo() + //^ unresolved + } + + //- bar.rs + pub trait T { + fn foo(&self); + } + + impl T for [i32; 1] { + fn foo(&self) {} + } + """) + + fun `test trait impl for const generic 3`() = stubOnlyResolve(""" + //- main.rs + #![feature(const_generics)] + mod bar; + use bar::T; + fn main() { + let s = [0]; + s.foo() + //^ bar.rs + } + + //- bar.rs + pub trait T { + fn foo(&self); + } + + impl T for [i32; N] { + fn foo(&self) {} + } + """) + + fun `test trait impl const generic 4`() = stubOnlyResolve(""" + //- main.rs + #![feature(const_generics)] + mod bar; + use bar::T; + fn main() { + let s = bar::S::<0>; + s.foo() + //^ bar.rs + } + + //- bar.rs + pub struct S; + + pub trait T { + fn foo(&self); + } + + impl T for S<{ 0 }> { + fn foo(&self) {} + } + """) + + fun `test trait impl const generic 5`() = stubOnlyResolve(""" + //- main.rs + #![feature(const_generics)] + mod bar; + use bar::T; + fn main() { + let s = bar::S::<1>; + s.foo() + //^ unresolved + } + + //- bar.rs + pub struct S; + + pub trait T { + fn foo(&self); + } + + impl T for S<{ 0 }> { + fn foo(&self) {} + } + """) + + fun `test trait impl const generic 6`() = stubOnlyResolve(""" + //- main.rs + #![feature(const_generics)] + mod bar; + use bar::T; + fn main() { + let s = bar::S::<0>; + s.foo() + //^ bar.rs + } + + //- bar.rs + pub struct S; + + pub trait T { + fn foo(&self); + } + + impl T for S<{ N }> { + fn foo(&self) {} + } + """) } diff --git a/src/test/kotlin/org/rust/lang/core/type/RsStubOnlyTypeInferenceTest.kt b/src/test/kotlin/org/rust/lang/core/type/RsStubOnlyTypeInferenceTest.kt index ed100c69655..42082e56681 100644 --- a/src/test/kotlin/org/rust/lang/core/type/RsStubOnlyTypeInferenceTest.kt +++ b/src/test/kotlin/org/rust/lang/core/type/RsStubOnlyTypeInferenceTest.kt @@ -52,4 +52,189 @@ class RsStubOnlyTypeInferenceTest : RsTypificationTestBase() { //^ [i32; 2] } """) + + fun `test const generic in path (explicit)`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S; + + pub fn foo() -> S<{ N2 }> { S } + + //- main.rs + mod foo; + + fn main() { + let x = foo::foo::<0>(); + x; + //^ S<0> + } + """) + + fun `test const generic in path (implicit)`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + #[derive(Clone, Copy)] + pub struct S; + + pub fn foo() -> S<{ N2 }> { S } + pub fn bar(s: S<{ N3 }>) -> S<{ N3 }> { s } + + //- main.rs + mod foo; + + fn main() { + let x = foo::foo(); + let y: foo::S<0> = foo::bar(x); + x; + //^ S<0> + } + """) + + fun `test const generic in path (not inferred)`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + #[derive(Clone, Copy)] + pub struct S; + + pub fn foo() -> S<{ N2 }> { S } + pub fn bar(s: S<{ N3 }>) -> S<{ N3 }> { s } + + //- main.rs + mod foo; + + fn main() { + let x = foo::foo(); + let y = foo::bar(x); + x; + //^ S<> + } + """) + + fun `test const generic (block expr) 1`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S; + + //- main.rs + mod foo; + + fn main() { + let x = foo::S::<{ 0 }>; + x; + //^ S<0> + } + """) + + fun `test const generic (block expr) 2`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S; + + //- main.rs + mod foo; + + fn main() { + let x = foo::S::<{ 1 + 1 }>; + x; + //^ S<2> + } + """) + + fun `test const generic (block expr) 3`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S; + + pub fn add1() -> S<{ N + 1 }> { S } + + //- main.rs + mod foo; + + fn main() { + let x = foo::add1::<1>(); + x; + //^ S<2> + } + """) + + fun `test const generic in method call (explicit)`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S1; + pub struct S2; + + impl S1<{ N3 }> { + pub fn foo(&self) -> S2<{ N3 }, { M3 }> { S2 } + } + + //- main.rs + mod foo; + + fn main() { + let x = foo::S1::<0>.foo::<1usize>(); + x; + //^ S2<0, 1> + } + """) + + fun `test const generic in method call (implicit)`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S1; + + #[derive(Clone, Copy)] + pub struct S2; + + impl S1<{ N3 }> { + pub fn foo(&self) -> S2<{ N3 }, { M3 }> { S2 } + } + + pub fn bar(s: S2<{ N4 }, { M4 }>) -> S2<{ N4 }, { M4 }> { s } + + //- main.rs + mod foo; + + fn main() { + let x = foo::S1.foo(); + let y: foo::S2<0, 1> = foo::bar(x); + x; + //^ S2<0, 1> + } + """) + + fun `test const generic in base type`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S; + + //- mod.rs + mod foo; + + fn bar() -> foo::S<{ N2 }> { foo::S } + //^ usize + """) + + fun `test const generic in trait ref`() = stubOnlyTypeInfer(""" + //- foo.rs + #![feature(const_generics)] + + pub struct S; + + pub trait T {} + + //- mod.rs + mod foo; + + impl foo::T<{ N2 }> for foo::S {} + //^ usize + """) }