From 0eee38c834633007c078e88a673f28c596a2e98f Mon Sep 17 00:00:00 2001 From: Mikhail Chernyavsky Date: Fri, 20 Dec 2019 17:16:09 +0300 Subject: [PATCH] TY: Check type of const generic argument --- src/main/grammars/RustParser.bnf | 4 +- .../RsDiagnosticBasedInspection.kt | 2 + .../core/psi/ext/RsInferenceContextOwner.kt | 2 + .../lang/core/types/infer/TypeInference.kt | 74 +++++++++++-------- .../core/types/infer/TypeInferenceWalker.kt | 14 +++- .../typecheck/RsTypeCheckInspectionTest.kt | 15 ++++ 6 files changed, 79 insertions(+), 32 deletions(-) diff --git a/src/main/grammars/RustParser.bnf b/src/main/grammars/RustParser.bnf index 48302d13b43..05387b1d53e 100644 --- a/src/main/grammars/RustParser.bnf +++ b/src/main/grammars/RustParser.bnf @@ -240,6 +240,7 @@ ValuePathGenericArgsNoTypeQual ::= < block is RsVariantDiscriminant -> expr is RsExpressionCodeFragment -> expr + is RsBaseType -> path?.typeArgumentList + is RsTraitRef -> path.typeArgumentList else -> null } diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt index d970782b8b3..249bd419016 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt @@ -15,6 +15,7 @@ import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* import org.rust.lang.core.resolve.* import org.rust.lang.core.resolve.ref.MethodResolveVariant +import org.rust.lang.core.resolve.ref.resolvePathRaw import org.rust.lang.core.types.* import org.rust.lang.core.types.regions.Region import org.rust.lang.core.types.ty.* @@ -171,40 +172,53 @@ class RsInferenceContext( } fun infer(element: RsInferenceContextOwner): RsInferenceResult { - if (element is RsFunction) { - val fctx = RsTypeInferenceWalker(this, element.returnType) - fctx.extractParameterBindings(element) - element.block?.let { fctx.inferFnBody(it) } - } else if (element is RsReplCodeFragment) { - element.context.inference?.let { - patTypes.putAll(it.patTypes) - patFieldTypes.putAll(it.patFieldTypes) - exprTypes.putAll(it.exprTypes) + when (element) { + is RsFunction -> { + val fctx = RsTypeInferenceWalker(this, element.returnType) + fctx.extractParameterBindings(element) + element.block?.let { fctx.inferFnBody(it) } } - - val walker = RsTypeInferenceWalker(this, TyUnknown) - walker.inferReplCodeFragment(element) - } else { - val (retTy, expr) = when (element) { - is RsConstant -> element.typeReference?.type to element.expr - is RsArrayType -> TyInteger.USize to element.expr - is RsVariantDiscriminant -> { - val enum = element.ancestorStrict() - enum?.reprType to element.expr + is RsReplCodeFragment -> { + element.context.inference?.let { + patTypes.putAll(it.patTypes) + patFieldTypes.putAll(it.patFieldTypes) + exprTypes.putAll(it.exprTypes) } - is RsExpressionCodeFragment -> { - element.context.inference?.let { - patTypes.putAll(it.patTypes) - patFieldTypes.putAll(it.patFieldTypes) - exprTypes.putAll(it.exprTypes) - } - null to element.expr + RsTypeInferenceWalker(this, TyUnknown).inferReplCodeFragment(element) + } + is RsBaseType, is RsTraitRef -> { + val path = when (element) { + is RsBaseType -> element.path + is RsTraitRef -> element.path + else -> null } - else -> error("Type inference is not implemented for PSI element of type " + - "`${element.javaClass}` that implement `RsInferenceContextOwner`") + val declaration = path?.let { resolvePathRaw(it, lookup) }?.singleOrNull()?.element as? RsGenericDeclaration + val constParameters = declaration?.constParameters.orEmpty() + val constArguments = path?.constArguments.orEmpty() + RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments) } - if (expr != null) { - RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr) + else -> { + val (retTy, expr) = when (element) { + is RsConstant -> element.typeReference?.type to element.expr + is RsArrayType -> TyInteger.USize to element.expr + is RsVariantDiscriminant -> { + val enum = element.contextStrict() + enum?.reprType to element.expr + } + is RsExpressionCodeFragment -> { + element.context.inference?.let { + patTypes.putAll(it.patTypes) + patFieldTypes.putAll(it.patFieldTypes) + exprTypes.putAll(it.exprTypes) + } + null to element.expr + } + else -> error("Type inference is not implemented for PSI element of type " + + "`${element.javaClass}` that implement `RsInferenceContextOwner`") + } + if (expr != null) { + RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr) + } } } diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt index 6756c2c2129..b302eac72fb 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt @@ -392,7 +392,13 @@ class RsTypeInferenceWalker( scopeEntry: ScopeEntry, pathExpr: RsPathExpr ): Ty { - val subst = instantiatePathGenerics(pathExpr.path, BoundElement(element, scopeEntry.subst)).subst + val path = pathExpr.path + val subst = instantiatePathGenerics(path, BoundElement(element, scopeEntry.subst)).subst + + if (element is RsGenericDeclaration) { + inferConstArgumentTypes(element.constParameters, path.constArguments) + } + val type = when (element) { is RsPatBinding -> ctx.getBindingType(element) is RsTypeDeclarationElement -> element.declaredType @@ -579,6 +585,8 @@ class RsTypeInferenceWalker( return methodType.retType } + inferConstArgumentTypes(callee.element.constParameters, methodCall.constArguments) + ctx.addDerefAdjustments(methodCall.receiver, callee.derefChain) if (callee.borrow != null) { ctx.addAdjustment(methodCall.receiver, Adjustment.BorrowReference(callee.methodSelfTy as TyReference)) @@ -731,6 +739,10 @@ class RsTypeInferenceWalker( } } + fun inferConstArgumentTypes(constParameters: List, constArguments: List) { + inferArgumentTypes(constParameters.map { it.typeReference?.type ?: TyUnknown }, constArguments) + } + private fun inferFieldExprType(receiver: Ty, fieldLookup: RsFieldLookup): Ty { if (fieldLookup.identifier?.text == "await" && fieldLookup.isEdition2018) { return receiver.lookupFutureOutputTy(lookup) diff --git a/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt b/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt index 46ac2f4b24f..bc4ec9c4c86 100644 --- a/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt +++ b/src/test/kotlin/org/rust/ide/inspections/typecheck/RsTypeCheckInspectionTest.kt @@ -27,6 +27,21 @@ class RsTypeCheckInspectionTest : RsInspectionsTestBase(RsTypeCheckInspection::c const A: [u8; 1u8] = [0]; """) + fun `test typecheck in const argument`() = checkByText(""" + #![feature(const_generics)] + struct S; + trait T { + fn foo(&self) -> S<{ N }>; + } + impl T<1u8> for S<1u8> { + fn foo(self) -> S<1u8> { self } + } + fn bar(x: S<1u8>) -> S<1u8> { + let s: S<1u8> = S::<1u8>; + s.foo::<1u8>() + } + """) + fun `test typecheck in enum variant discriminant`() = checkByText(""" enum Foo { BAR = 1u8 } """)