diff --git a/src/main/grammars/RustParser.bnf b/src/main/grammars/RustParser.bnf index 48302d13b43..05387b1d53e 100644 --- a/src/main/grammars/RustParser.bnf +++ b/src/main/grammars/RustParser.bnf @@ -240,6 +240,7 @@ ValuePathGenericArgsNoTypeQual ::= < { val declaration = resolved as RsGenericDeclaration diff --git a/src/main/kotlin/org/rust/ide/inspections/RsWrongLifetimeParametersNumberInspection.kt b/src/main/kotlin/org/rust/ide/inspections/RsWrongLifetimeParametersNumberInspection.kt index 40f391fc3d2..c17c8ef738f 100644 --- a/src/main/kotlin/org/rust/ide/inspections/RsWrongLifetimeParametersNumberInspection.kt +++ b/src/main/kotlin/org/rust/ide/inspections/RsWrongLifetimeParametersNumberInspection.kt @@ -9,6 +9,8 @@ import org.rust.lang.core.psi.RsBaseType import org.rust.lang.core.psi.RsRefLikeType import org.rust.lang.core.psi.RsVisitor import org.rust.lang.core.psi.ext.RsGenericDeclaration +import org.rust.lang.core.psi.ext.lifetimeArguments +import org.rust.lang.core.psi.ext.lifetimeParameters import org.rust.lang.core.types.lifetimeElidable import org.rust.lang.utils.RsDiagnostic import org.rust.lang.utils.addToHolder @@ -23,8 +25,8 @@ class RsWrongLifetimeParametersNumberInspection : RsLocalInspectionTool() { if (type.path?.cself != null) return val paramsDecl = type.path?.reference?.resolve() as? RsGenericDeclaration ?: return - val expectedLifetimes = paramsDecl.typeParameterList?.lifetimeParameterList?.size ?: 0 - val actualLifetimes = type.path?.typeArgumentList?.lifetimeList?.size ?: 0 + val expectedLifetimes = paramsDecl.lifetimeParameters.size + val actualLifetimes = type.path?.lifetimeArguments?.size ?: 0 if (expectedLifetimes == actualLifetimes) return if (actualLifetimes == 0 && !type.lifetimeElidable) { RsDiagnostic.MissingLifetimeSpecifier(type).addToHolder(holder) diff --git a/src/main/kotlin/org/rust/ide/refactoring/extractFunction/RsExtractFunctionConfig.kt b/src/main/kotlin/org/rust/ide/refactoring/extractFunction/RsExtractFunctionConfig.kt index bb518968ba2..104de900d09 100644 --- a/src/main/kotlin/org/rust/ide/refactoring/extractFunction/RsExtractFunctionConfig.kt +++ b/src/main/kotlin/org/rust/ide/refactoring/extractFunction/RsExtractFunctionConfig.kt @@ -173,7 +173,7 @@ class RsExtractFunctionConfig private constructor( val type = it.declaredType val bounds = mutableSetOf() it.bounds.flatMapTo(bounds) { - it.bound.traitRef?.path?.typeArgumentList?.typeReferenceList?.flatMap { it.type.types() } ?: emptyList() + it.bound.traitRef?.path?.typeArguments?.flatMap { it.type.types() }.orEmpty() } type to bounds } diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsGenericDeclaration.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsGenericDeclaration.kt index 3a4f8f2c887..919baded888 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsGenericDeclaration.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsGenericDeclaration.kt @@ -15,8 +15,8 @@ interface RsGenericDeclaration : RsElement { val RsGenericDeclaration.typeParameters: List get() = typeParameterList?.typeParameterList.orEmpty() -val RsGenericDeclaration.constParameters: List - get() = typeParameterList?.constParameterList.orEmpty() - val RsGenericDeclaration.lifetimeParameters: List get() = typeParameterList?.lifetimeParameterList.orEmpty() + +val RsGenericDeclaration.constParameters: List + get() = typeParameterList?.constParameterList.orEmpty() diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsInferenceContextOwner.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsInferenceContextOwner.kt index 1e6f920b217..56b8238f653 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsInferenceContextOwner.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsInferenceContextOwner.kt @@ -22,5 +22,7 @@ val RsInferenceContextOwner.body: RsElement? is RsFunction -> block is RsVariantDiscriminant -> expr is RsExpressionCodeFragment -> expr + is RsBaseType -> path?.typeArgumentList + is RsTraitRef -> path.typeArgumentList else -> null } diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsMethodCall.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsMethodCall.kt index 4d852ee9692..77bc3645f5b 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsMethodCall.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsMethodCall.kt @@ -8,11 +8,20 @@ package org.rust.lang.core.psi.ext import com.intellij.lang.ASTNode import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiElement +import org.rust.lang.core.psi.RsExpr +import org.rust.lang.core.psi.RsLifetime import org.rust.lang.core.psi.RsMethodCall +import org.rust.lang.core.psi.RsTypeReference import org.rust.lang.core.resolve.ref.RsMethodCallReferenceImpl import org.rust.lang.core.resolve.ref.RsReference -val RsMethodCall.textRangeWithoutValueArguments +val RsMethodCall.lifetimeArguments: List get() = typeArgumentList?.lifetimeList.orEmpty() + +val RsMethodCall.typeArguments: List get() = typeArgumentList?.typeReferenceList.orEmpty() + +val RsMethodCall.constArguments: List get() = typeArgumentList?.exprList.orEmpty() + +val RsMethodCall.textRangeWithoutValueArguments: TextRange get() = TextRange(startOffset, typeArgumentList?.endOffset ?: identifier.endOffset) abstract class RsMethodCallImplMixin(node: ASTNode) : RsElementImpl(node), RsMethodCall { diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsPath.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsPath.kt index c93a03eaca6..6b0e042b189 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsPath.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsPath.kt @@ -81,6 +81,12 @@ fun RsPath.allowedNamespaces(isCompletion: Boolean = false): Set = wh else -> TYPES_N_VALUES } +val RsPath.lifetimeArguments: List get() = typeArgumentList?.lifetimeList.orEmpty() + +val RsPath.typeArguments: List get() = typeArgumentList?.typeReferenceList.orEmpty() + +val RsPath.constArguments: List get() = typeArgumentList?.exprList.orEmpty() + abstract class RsPathImplMixin : RsStubbedElementImpl, RsPath { constructor(node: ASTNode) : super(node) diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt index 188c4514c8a..57b0da5ebd9 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt @@ -95,7 +95,7 @@ val RsTraitItem.isSized: Boolean get() { } fun RsTraitItem.withSubst(vararg subst: Ty): BoundElement { - val typeParameterList = typeParameterList?.typeParameterList.orEmpty() + val typeParameterList = typeParameters val substitution = if (typeParameterList.size != subst.size) { LOG.warn("Trait has ${typeParameterList.size} type parameters but received ${subst.size} types for substitution") emptySubstitution diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt b/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt index 7bf64938178..613cff12999 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt @@ -27,8 +27,7 @@ import org.rust.stdext.buildList import kotlin.LazyThreadSafetyMode.NONE private val RsTraitItem.typeParamSingle: TyTypeParameter? - get() = - typeParameterList?.typeParameterList?.singleOrNull()?.let { TyTypeParameter.named(it) } + get() = typeParameters.singleOrNull()?.let { TyTypeParameter.named(it) } const val DEFAULT_RECURSION_LIMIT = 64 diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt index ee099a9fce6..4ea410992db 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt @@ -125,7 +125,8 @@ fun instantiatePathGenerics( resolved: BoundElement ): BoundElement { val (element, subst) = resolved.downcast() ?: return resolved - val typeArguments: List? = run { + + val typeArguments = run { val inAngles = path.typeArgumentList val fnSugar = path.valueParameterList when { @@ -136,7 +137,6 @@ fun instantiatePathGenerics( else -> null } } - val regionArguments: List? = path.typeArgumentList?.lifetimeList?.map { it.resolve() } val outputArg = path.retType?.typeReference?.type val assocTypes = run { @@ -186,7 +186,9 @@ fun instantiatePathGenerics( } paramTy to value } + val regionParameters = element.lifetimeParameters.map { ReEarlyBound(it) } + val regionArguments = path.typeArgumentList?.lifetimeList?.map { it.resolve() } val regionSubst = regionParameters.zip(regionArguments ?: regionParameters).toMap() val newSubst = Substitution(typeSubst, regionSubst) return BoundElement(resolved.element, subst + newSubst, assocTypes) diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt index d970782b8b3..249bd419016 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt @@ -15,6 +15,7 @@ import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* import org.rust.lang.core.resolve.* import org.rust.lang.core.resolve.ref.MethodResolveVariant +import org.rust.lang.core.resolve.ref.resolvePathRaw import org.rust.lang.core.types.* import org.rust.lang.core.types.regions.Region import org.rust.lang.core.types.ty.* @@ -171,40 +172,53 @@ class RsInferenceContext( } fun infer(element: RsInferenceContextOwner): RsInferenceResult { - if (element is RsFunction) { - val fctx = RsTypeInferenceWalker(this, element.returnType) - fctx.extractParameterBindings(element) - element.block?.let { fctx.inferFnBody(it) } - } else if (element is RsReplCodeFragment) { - element.context.inference?.let { - patTypes.putAll(it.patTypes) - patFieldTypes.putAll(it.patFieldTypes) - exprTypes.putAll(it.exprTypes) + when (element) { + is RsFunction -> { + val fctx = RsTypeInferenceWalker(this, element.returnType) + fctx.extractParameterBindings(element) + element.block?.let { fctx.inferFnBody(it) } } - - val walker = RsTypeInferenceWalker(this, TyUnknown) - walker.inferReplCodeFragment(element) - } else { - val (retTy, expr) = when (element) { - is RsConstant -> element.typeReference?.type to element.expr - is RsArrayType -> TyInteger.USize to element.expr - is RsVariantDiscriminant -> { - val enum = element.ancestorStrict() - enum?.reprType to element.expr + is RsReplCodeFragment -> { + element.context.inference?.let { + patTypes.putAll(it.patTypes) + patFieldTypes.putAll(it.patFieldTypes) + exprTypes.putAll(it.exprTypes) } - is RsExpressionCodeFragment -> { - element.context.inference?.let { - patTypes.putAll(it.patTypes) - patFieldTypes.putAll(it.patFieldTypes) - exprTypes.putAll(it.exprTypes) - } - null to element.expr + RsTypeInferenceWalker(this, TyUnknown).inferReplCodeFragment(element) + } + is RsBaseType, is RsTraitRef -> { + val path = when (element) { + is RsBaseType -> element.path + is RsTraitRef -> element.path + else -> null } - else -> error("Type inference is not implemented for PSI element of type " + - "`${element.javaClass}` that implement `RsInferenceContextOwner`") + val declaration = path?.let { resolvePathRaw(it, lookup) }?.singleOrNull()?.element as? RsGenericDeclaration + val constParameters = declaration?.constParameters.orEmpty() + val constArguments = path?.constArguments.orEmpty() + RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments) } - if (expr != null) { - RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr) + else -> { + val (retTy, expr) = when (element) { + is RsConstant -> element.typeReference?.type to element.expr + is RsArrayType -> TyInteger.USize to element.expr + is RsVariantDiscriminant -> { + val enum = element.contextStrict() + enum?.reprType to element.expr + } + is RsExpressionCodeFragment -> { + element.context.inference?.let { + patTypes.putAll(it.patTypes) + patFieldTypes.putAll(it.patFieldTypes) + exprTypes.putAll(it.exprTypes) + } + null to element.expr + } + else -> error("Type inference is not implemented for PSI element of type " + + "`${element.javaClass}` that implement `RsInferenceContextOwner`") + } + if (expr != null) { + RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr) + } } } diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt index 6756c2c2129..b302eac72fb 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt @@ -392,7 +392,13 @@ class RsTypeInferenceWalker( scopeEntry: ScopeEntry, pathExpr: RsPathExpr ): Ty { - val subst = instantiatePathGenerics(pathExpr.path, BoundElement(element, scopeEntry.subst)).subst + val path = pathExpr.path + val subst = instantiatePathGenerics(path, BoundElement(element, scopeEntry.subst)).subst + + if (element is RsGenericDeclaration) { + inferConstArgumentTypes(element.constParameters, path.constArguments) + } + val type = when (element) { is RsPatBinding -> ctx.getBindingType(element) is RsTypeDeclarationElement -> element.declaredType @@ -579,6 +585,8 @@ class RsTypeInferenceWalker( return methodType.retType } + inferConstArgumentTypes(callee.element.constParameters, methodCall.constArguments) + ctx.addDerefAdjustments(methodCall.receiver, callee.derefChain) if (callee.borrow != null) { ctx.addAdjustment(methodCall.receiver, Adjustment.BorrowReference(callee.methodSelfTy as TyReference)) @@ -731,6 +739,10 @@ class RsTypeInferenceWalker( } } + fun inferConstArgumentTypes(constParameters: List, constArguments: List) { + inferArgumentTypes(constParameters.map { it.typeReference?.type ?: TyUnknown }, constArguments) + } + private fun inferFieldExprType(receiver: Ty, fieldLookup: RsFieldLookup): Ty { if (fieldLookup.identifier?.text == "await" && fieldLookup.isEdition2018) { return receiver.lookupFutureOutputTy(lookup) diff --git a/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt b/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt index 46ac2f4b24f..bc4ec9c4c86 100644 --- a/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt +++ b/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt @@ -27,6 +27,21 @@ class RsTypeCheckInspectionTest : RsInspectionsTestBase(RsTypeCheckInspection::c const A: [u8; 1u8] = [0]; """) + fun `test typecheck in const argument`() = checkByText(""" + #![feature(const_generics)] + struct S; + trait T { + fn foo(&self) -> S<{ N }>; + } + impl T<1u8> for S<1u8> { + fn foo(self) -> S<1u8> { self } + } + fn bar(x: S<1u8>) -> S<1u8> { + let s: S<1u8> = S::<1u8>; + s.foo::<1u8>() + } + """) + fun `test typecheck in enum variant discriminant`() = checkByText(""" enum Foo { BAR = 1u8 } """)