Skip to content

Commit

Permalink
Merge #4677
Browse files Browse the repository at this point in the history
4677: TY: Check type of const generic argument r=vlad20012 a=mchernyavsky

Relates to #3985.

Co-authored-by: mchernyavsky <mikhail.chernyavsky@jetbrains.com>
Co-authored-by: Mikhail Chernyavsky <mikhail.chernyavsky@jetbrains.com>
  • Loading branch information
bors[bot] and mchernyavsky committed Dec 20, 2019
2 parents 6218207 + 0eee38c commit 0849b00
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 45 deletions.
4 changes: 3 additions & 1 deletion src/main/grammars/RustParser.bnf
Expand Up @@ -240,6 +240,7 @@ ValuePathGenericArgsNoTypeQual ::= <<typeQuals 'OFF' <<pathMode 'VALUE' PathImpl

// Path semantically constrained to resolve to a trait
TraitRef ::= TypePathGenericArgsNoTypeQual {
implements = "org.rust.lang.core.psi.ext.RsInferenceContextOwner"
extends = "org.rust.lang.core.psi.ext.RsStubbedElementImpl<?>"
stubClass = "org.rust.lang.core.stubs.RsPlaceholderStub"
elementTypeFactory = "org.rust.lang.core.stubs.StubImplementationsKt.factory"
Expand Down Expand Up @@ -891,7 +892,8 @@ ForInType ::= ForLifetimes (FnPointerType | TraitRef) {
}

BaseType ::= TrivialBaseTypeInner | TypePathGenericArgs {
implements = "org.rust.lang.core.psi.ext.RsTypeElement"
implements = [ "org.rust.lang.core.psi.ext.RsTypeElement"
"org.rust.lang.core.psi.ext.RsInferenceContextOwner" ]
extends = "org.rust.lang.core.psi.ext.RsStubbedElementImpl<?>"
stubClass = "org.rust.lang.core.stubs.RsBaseTypeStub"
elementTypeFactory = "org.rust.lang.core.stubs.StubImplementationsKt.factory"
Expand Down
Expand Up @@ -15,6 +15,8 @@ abstract class RsDiagnosticBasedInspection : RsLocalInspectionTool() {
override fun visitFunction(o: RsFunction) = collectDiagnostics(holder, o)
override fun visitConstant(o: RsConstant) = collectDiagnostics(holder, o)
override fun visitArrayType(o: RsArrayType) = collectDiagnostics(holder, o)
override fun visitBaseType(o: RsBaseType) = collectDiagnostics(holder, o)
override fun visitTraitRef(o: RsTraitRef) = collectDiagnostics(holder, o)
override fun visitVariantDiscriminant(o: RsVariantDiscriminant) = collectDiagnostics(holder, o)
}

Expand Down
Expand Up @@ -133,7 +133,7 @@ private class LifetimesCollector(val isForInputParams: Boolean = false) : RsVisi
}

private fun collectAnonymousLifetimes(path: RsPath) {
if (path.typeArgumentList?.lifetimeList.orEmpty().isNotEmpty()) return
if (path.lifetimeArguments.isNotEmpty()) return
when (val resolved = path.reference.resolve()) {
is RsStructItem, is RsTraitItem, is RsTypeAlias -> {
val declaration = resolved as RsGenericDeclaration
Expand Down
Expand Up @@ -9,6 +9,8 @@ import org.rust.lang.core.psi.RsBaseType
import org.rust.lang.core.psi.RsRefLikeType
import org.rust.lang.core.psi.RsVisitor
import org.rust.lang.core.psi.ext.RsGenericDeclaration
import org.rust.lang.core.psi.ext.lifetimeArguments
import org.rust.lang.core.psi.ext.lifetimeParameters
import org.rust.lang.core.types.lifetimeElidable
import org.rust.lang.utils.RsDiagnostic
import org.rust.lang.utils.addToHolder
Expand All @@ -23,8 +25,8 @@ class RsWrongLifetimeParametersNumberInspection : RsLocalInspectionTool() {
if (type.path?.cself != null) return

val paramsDecl = type.path?.reference?.resolve() as? RsGenericDeclaration ?: return
val expectedLifetimes = paramsDecl.typeParameterList?.lifetimeParameterList?.size ?: 0
val actualLifetimes = type.path?.typeArgumentList?.lifetimeList?.size ?: 0
val expectedLifetimes = paramsDecl.lifetimeParameters.size
val actualLifetimes = type.path?.lifetimeArguments?.size ?: 0
if (expectedLifetimes == actualLifetimes) return
if (actualLifetimes == 0 && !type.lifetimeElidable) {
RsDiagnostic.MissingLifetimeSpecifier(type).addToHolder(holder)
Expand Down
Expand Up @@ -173,7 +173,7 @@ class RsExtractFunctionConfig private constructor(
val type = it.declaredType
val bounds = mutableSetOf<Ty>()
it.bounds.flatMapTo(bounds) {
it.bound.traitRef?.path?.typeArgumentList?.typeReferenceList?.flatMap { it.type.types() } ?: emptyList()
it.bound.traitRef?.path?.typeArguments?.flatMap { it.type.types() }.orEmpty()
}
type to bounds
}
Expand Down
Expand Up @@ -15,8 +15,8 @@ interface RsGenericDeclaration : RsElement {
val RsGenericDeclaration.typeParameters: List<RsTypeParameter>
get() = typeParameterList?.typeParameterList.orEmpty()

val RsGenericDeclaration.constParameters: List<RsConstParameter>
get() = typeParameterList?.constParameterList.orEmpty()

val RsGenericDeclaration.lifetimeParameters: List<RsLifetimeParameter>
get() = typeParameterList?.lifetimeParameterList.orEmpty()

val RsGenericDeclaration.constParameters: List<RsConstParameter>
get() = typeParameterList?.constParameterList.orEmpty()
Expand Up @@ -22,5 +22,7 @@ val RsInferenceContextOwner.body: RsElement?
is RsFunction -> block
is RsVariantDiscriminant -> expr
is RsExpressionCodeFragment -> expr
is RsBaseType -> path?.typeArgumentList
is RsTraitRef -> path.typeArgumentList
else -> null
}
11 changes: 10 additions & 1 deletion src/main/kotlin/org/rust/lang/core/psi/ext/RsMethodCall.kt
Expand Up @@ -8,11 +8,20 @@ package org.rust.lang.core.psi.ext
import com.intellij.lang.ASTNode
import com.intellij.openapi.util.TextRange
import com.intellij.psi.PsiElement
import org.rust.lang.core.psi.RsExpr
import org.rust.lang.core.psi.RsLifetime
import org.rust.lang.core.psi.RsMethodCall
import org.rust.lang.core.psi.RsTypeReference
import org.rust.lang.core.resolve.ref.RsMethodCallReferenceImpl
import org.rust.lang.core.resolve.ref.RsReference

val RsMethodCall.textRangeWithoutValueArguments
val RsMethodCall.lifetimeArguments: List<RsLifetime> get() = typeArgumentList?.lifetimeList.orEmpty()

val RsMethodCall.typeArguments: List<RsTypeReference> get() = typeArgumentList?.typeReferenceList.orEmpty()

val RsMethodCall.constArguments: List<RsExpr> get() = typeArgumentList?.exprList.orEmpty()

val RsMethodCall.textRangeWithoutValueArguments: TextRange
get() = TextRange(startOffset, typeArgumentList?.endOffset ?: identifier.endOffset)

abstract class RsMethodCallImplMixin(node: ASTNode) : RsElementImpl(node), RsMethodCall {
Expand Down
6 changes: 6 additions & 0 deletions src/main/kotlin/org/rust/lang/core/psi/ext/RsPath.kt
Expand Up @@ -81,6 +81,12 @@ fun RsPath.allowedNamespaces(isCompletion: Boolean = false): Set<Namespace> = wh
else -> TYPES_N_VALUES
}

val RsPath.lifetimeArguments: List<RsLifetime> get() = typeArgumentList?.lifetimeList.orEmpty()

val RsPath.typeArguments: List<RsTypeReference> get() = typeArgumentList?.typeReferenceList.orEmpty()

val RsPath.constArguments: List<RsExpr> get() = typeArgumentList?.exprList.orEmpty()

abstract class RsPathImplMixin : RsStubbedElementImpl<RsPathStub>,
RsPath {
constructor(node: ASTNode) : super(node)
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/rust/lang/core/psi/ext/RsTraitItem.kt
Expand Up @@ -95,7 +95,7 @@ val RsTraitItem.isSized: Boolean get() {
}

fun RsTraitItem.withSubst(vararg subst: Ty): BoundElement<RsTraitItem> {
val typeParameterList = typeParameterList?.typeParameterList.orEmpty()
val typeParameterList = typeParameters
val substitution = if (typeParameterList.size != subst.size) {
LOG.warn("Trait has ${typeParameterList.size} type parameters but received ${subst.size} types for substitution")
emptySubstitution
Expand Down
3 changes: 1 addition & 2 deletions src/main/kotlin/org/rust/lang/core/resolve/ImplLookup.kt
Expand Up @@ -27,8 +27,7 @@ import org.rust.stdext.buildList
import kotlin.LazyThreadSafetyMode.NONE

private val RsTraitItem.typeParamSingle: TyTypeParameter?
get() =
typeParameterList?.typeParameterList?.singleOrNull()?.let { TyTypeParameter.named(it) }
get() = typeParameters.singleOrNull()?.let { TyTypeParameter.named(it) }

const val DEFAULT_RECURSION_LIMIT = 64

Expand Down
Expand Up @@ -125,7 +125,8 @@ fun <T: RsElement> instantiatePathGenerics(
resolved: BoundElement<T>
): BoundElement<T> {
val (element, subst) = resolved.downcast<RsGenericDeclaration>() ?: return resolved
val typeArguments: List<Ty>? = run {

val typeArguments = run {
val inAngles = path.typeArgumentList
val fnSugar = path.valueParameterList
when {
Expand All @@ -136,7 +137,6 @@ fun <T: RsElement> instantiatePathGenerics(
else -> null
}
}
val regionArguments: List<Region>? = path.typeArgumentList?.lifetimeList?.map { it.resolve() }
val outputArg = path.retType?.typeReference?.type

val assocTypes = run {
Expand Down Expand Up @@ -186,7 +186,9 @@ fun <T: RsElement> instantiatePathGenerics(
}
paramTy to value
}

val regionParameters = element.lifetimeParameters.map { ReEarlyBound(it) }
val regionArguments = path.typeArgumentList?.lifetimeList?.map { it.resolve() }
val regionSubst = regionParameters.zip(regionArguments ?: regionParameters).toMap()
val newSubst = Substitution(typeSubst, regionSubst)
return BoundElement(resolved.element, subst + newSubst, assocTypes)
Expand Down
74 changes: 44 additions & 30 deletions src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt
Expand Up @@ -15,6 +15,7 @@ import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.*
import org.rust.lang.core.resolve.ref.MethodResolveVariant
import org.rust.lang.core.resolve.ref.resolvePathRaw
import org.rust.lang.core.types.*
import org.rust.lang.core.types.regions.Region
import org.rust.lang.core.types.ty.*
Expand Down Expand Up @@ -171,40 +172,53 @@ class RsInferenceContext(
}

fun infer(element: RsInferenceContextOwner): RsInferenceResult {
if (element is RsFunction) {
val fctx = RsTypeInferenceWalker(this, element.returnType)
fctx.extractParameterBindings(element)
element.block?.let { fctx.inferFnBody(it) }
} else if (element is RsReplCodeFragment) {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
when (element) {
is RsFunction -> {
val fctx = RsTypeInferenceWalker(this, element.returnType)
fctx.extractParameterBindings(element)
element.block?.let { fctx.inferFnBody(it) }
}

val walker = RsTypeInferenceWalker(this, TyUnknown)
walker.inferReplCodeFragment(element)
} else {
val (retTy, expr) = when (element) {
is RsConstant -> element.typeReference?.type to element.expr
is RsArrayType -> TyInteger.USize to element.expr
is RsVariantDiscriminant -> {
val enum = element.ancestorStrict<RsEnumItem>()
enum?.reprType to element.expr
is RsReplCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
is RsExpressionCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
null to element.expr
RsTypeInferenceWalker(this, TyUnknown).inferReplCodeFragment(element)
}
is RsBaseType, is RsTraitRef -> {
val path = when (element) {
is RsBaseType -> element.path
is RsTraitRef -> element.path
else -> null
}
else -> error("Type inference is not implemented for PSI element of type " +
"`${element.javaClass}` that implement `RsInferenceContextOwner`")
val declaration = path?.let { resolvePathRaw(it, lookup) }?.singleOrNull()?.element as? RsGenericDeclaration
val constParameters = declaration?.constParameters.orEmpty()
val constArguments = path?.constArguments.orEmpty()
RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments)
}
if (expr != null) {
RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr)
else -> {
val (retTy, expr) = when (element) {
is RsConstant -> element.typeReference?.type to element.expr
is RsArrayType -> TyInteger.USize to element.expr
is RsVariantDiscriminant -> {
val enum = element.contextStrict<RsEnumItem>()
enum?.reprType to element.expr
}
is RsExpressionCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
null to element.expr
}
else -> error("Type inference is not implemented for PSI element of type " +
"`${element.javaClass}` that implement `RsInferenceContextOwner`")
}
if (expr != null) {
RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr)
}
}
}

Expand Down
Expand Up @@ -392,7 +392,13 @@ class RsTypeInferenceWalker(
scopeEntry: ScopeEntry,
pathExpr: RsPathExpr
): Ty {
val subst = instantiatePathGenerics(pathExpr.path, BoundElement(element, scopeEntry.subst)).subst
val path = pathExpr.path
val subst = instantiatePathGenerics(path, BoundElement(element, scopeEntry.subst)).subst

if (element is RsGenericDeclaration) {
inferConstArgumentTypes(element.constParameters, path.constArguments)
}

val type = when (element) {
is RsPatBinding -> ctx.getBindingType(element)
is RsTypeDeclarationElement -> element.declaredType
Expand Down Expand Up @@ -579,6 +585,8 @@ class RsTypeInferenceWalker(
return methodType.retType
}

inferConstArgumentTypes(callee.element.constParameters, methodCall.constArguments)

ctx.addDerefAdjustments(methodCall.receiver, callee.derefChain)
if (callee.borrow != null) {
ctx.addAdjustment(methodCall.receiver, Adjustment.BorrowReference(callee.methodSelfTy as TyReference))
Expand Down Expand Up @@ -731,6 +739,10 @@ class RsTypeInferenceWalker(
}
}

fun inferConstArgumentTypes(constParameters: List<RsConstParameter>, constArguments: List<RsExpr>) {
inferArgumentTypes(constParameters.map { it.typeReference?.type ?: TyUnknown }, constArguments)
}

private fun inferFieldExprType(receiver: Ty, fieldLookup: RsFieldLookup): Ty {
if (fieldLookup.identifier?.text == "await" && fieldLookup.isEdition2018) {
return receiver.lookupFutureOutputTy(lookup)
Expand Down
Expand Up @@ -27,6 +27,21 @@ class RsTypeCheckInspectionTest : RsInspectionsTestBase(RsTypeCheckInspection::c
const A: [u8; <error>1u8</error>] = [0];
""")

fun `test typecheck in const argument`() = checkByText("""
#![feature(const_generics)]
struct S<const N: usize>;
trait T<const N: usize> {
fn foo<const N: usize>(&self) -> S<{ N }>;
}
impl T<<error>1u8</error>> for S<<error>1u8</error>> {
fn foo<const N: usize>(self) -> S<<error>1u8</error>> { self }
}
fn bar(x: S<<error>1u8</error>>) -> S<<error>1u8</error>> {
let s: S<<error>1u8</error>> = S::<<error>1u8</error>>;
s.foo::<<error>1u8</error>>()
}
""")

fun `test typecheck in enum variant discriminant`() = checkByText("""
enum Foo { BAR = <error>1u8</error> }
""")
Expand Down

0 comments on commit 0849b00

Please sign in to comment.