Skip to content

Commit

Permalink
TY: Check type of const generic argument
Browse files Browse the repository at this point in the history
  • Loading branch information
mchernyavsky committed Dec 20, 2019
1 parent a8df8aa commit 2e8b363
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 36 deletions.
4 changes: 3 additions & 1 deletion src/main/grammars/RustParser.bnf
Expand Up @@ -240,6 +240,7 @@ ValuePathGenericArgsNoTypeQual ::= <<typeQuals 'OFF' <<pathMode 'VALUE' PathImpl

// Path semantically constrained to resolve to a trait
TraitRef ::= TypePathGenericArgsNoTypeQual {
implements = "org.rust.lang.core.psi.ext.RsInferenceContextOwner"
extends = "org.rust.lang.core.psi.ext.RsStubbedElementImpl<?>"
stubClass = "org.rust.lang.core.stubs.RsPlaceholderStub"
elementTypeFactory = "org.rust.lang.core.stubs.StubImplementationsKt.factory"
Expand Down Expand Up @@ -891,7 +892,8 @@ ForInType ::= ForLifetimes (FnPointerType | TraitRef) {
}

BaseType ::= TrivialBaseTypeInner | TypePathGenericArgs {
implements = "org.rust.lang.core.psi.ext.RsTypeElement"
implements = [ "org.rust.lang.core.psi.ext.RsTypeElement"
"org.rust.lang.core.psi.ext.RsInferenceContextOwner" ]
extends = "org.rust.lang.core.psi.ext.RsStubbedElementImpl<?>"
stubClass = "org.rust.lang.core.stubs.RsBaseTypeStub"
elementTypeFactory = "org.rust.lang.core.stubs.StubImplementationsKt.factory"
Expand Down
Expand Up @@ -15,6 +15,8 @@ abstract class RsDiagnosticBasedInspection : RsLocalInspectionTool() {
override fun visitFunction(o: RsFunction) = collectDiagnostics(holder, o)
override fun visitConstant(o: RsConstant) = collectDiagnostics(holder, o)
override fun visitArrayType(o: RsArrayType) = collectDiagnostics(holder, o)
override fun visitBaseType(o: RsBaseType) = collectDiagnostics(holder, o)
override fun visitTraitRef(o: RsTraitRef) = collectDiagnostics(holder, o)
override fun visitVariantDiscriminant(o: RsVariantDiscriminant) = collectDiagnostics(holder, o)
}

Expand Down
Expand Up @@ -22,5 +22,7 @@ val RsInferenceContextOwner.body: RsElement?
is RsFunction -> block
is RsVariantDiscriminant -> expr
is RsExpressionCodeFragment -> expr
is RsBaseType -> path?.typeArgumentList
is RsTraitRef -> path.typeArgumentList
else -> null
}
74 changes: 44 additions & 30 deletions src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt
Expand Up @@ -15,6 +15,7 @@ import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.*
import org.rust.lang.core.resolve.ref.MethodResolveVariant
import org.rust.lang.core.resolve.ref.resolvePathRaw
import org.rust.lang.core.types.*
import org.rust.lang.core.types.regions.Region
import org.rust.lang.core.types.ty.*
Expand Down Expand Up @@ -171,40 +172,53 @@ class RsInferenceContext(
}

fun infer(element: RsInferenceContextOwner): RsInferenceResult {
if (element is RsFunction) {
val fctx = RsTypeInferenceWalker(this, element.returnType)
fctx.extractParameterBindings(element)
element.block?.let { fctx.inferFnBody(it) }
} else if (element is RsReplCodeFragment) {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
when (element) {
is RsFunction -> {
val fctx = RsTypeInferenceWalker(this, element.returnType)
fctx.extractParameterBindings(element)
element.block?.let { fctx.inferFnBody(it) }
}

val walker = RsTypeInferenceWalker(this, TyUnknown)
walker.inferReplCodeFragment(element)
} else {
val (retTy, expr) = when (element) {
is RsConstant -> element.typeReference?.type to element.expr
is RsArrayType -> TyInteger.USize to element.expr
is RsVariantDiscriminant -> {
val enum = element.ancestorStrict<RsEnumItem>()
enum?.reprType to element.expr
is RsReplCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
is RsExpressionCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
null to element.expr
RsTypeInferenceWalker(this, TyUnknown).inferReplCodeFragment(element)
}
is RsBaseType, is RsTraitRef -> {
val path = when (element) {
is RsBaseType -> element.path
is RsTraitRef -> element.path
else -> null
}
else -> error("Type inference is not implemented for PSI element of type " +
"`${element.javaClass}` that implement `RsInferenceContextOwner`")
val declaration = path?.let { resolvePathRaw(it, lookup) }?.singleOrNull()?.element as? RsGenericDeclaration
val constParameters = declaration?.constParameters.orEmpty()
val constArguments = path?.constArguments.orEmpty()
RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments)
}
if (expr != null) {
RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr)
else -> {
val (retTy, expr) = when (element) {
is RsConstant -> element.typeReference?.type to element.expr
is RsArrayType -> TyInteger.USize to element.expr
is RsVariantDiscriminant -> {
val enum = element.contextStrict<RsEnumItem>()
enum?.reprType to element.expr
}
is RsExpressionCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
null to element.expr
}
else -> error("Type inference is not implemented for PSI element of type " +
"`${element.javaClass}` that implement `RsInferenceContextOwner`")
}
if (expr != null) {
RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr)
}
}
}

Expand Down
Expand Up @@ -334,6 +334,7 @@ class RsTypeInferenceWalker(
subst = subst,
source = TraitImplSource.Collapsed((resolved.owner as RsAbstractableOwner.Trait).trait)
)
inferConstArgumentTypes(resolved.constParameters, path.constArguments)
return instantiatePath(resolved, scopeEntry, expr)
}
}
Expand All @@ -345,7 +346,13 @@ class RsTypeInferenceWalker(
ctx.writePath(expr, filteredVariants.mapNotNull { ResolvedPath.from(it) })

val first = filteredVariants.singleOrNull() ?: return TyUnknown
return instantiatePath(first.element ?: return TyUnknown, first, expr)
val resolved = first.element ?: return TyUnknown

if (resolved is RsGenericDeclaration) {
inferConstArgumentTypes(resolved.constParameters, path.constArguments)
}

return instantiatePath(resolved, first, expr)
}

/** This works for `String::from` where multiple impls of `From` trait found for `String` */
Expand Down Expand Up @@ -573,37 +580,41 @@ class RsTypeInferenceWalker(

callee ?: variants.firstOrNull()?.let { MethodPick.from(it) }
}

if (callee == null) {
val methodType = unknownTyFunction(argExprs.size)
inferArgumentTypes(methodType.paramTypes, argExprs)
return methodType.retType
}

val resolved = callee.element
inferConstArgumentTypes(resolved.constParameters, methodCall.constArguments)

ctx.addDerefAdjustments(methodCall.receiver, callee.derefChain)
if (callee.borrow != null) {
ctx.addAdjustment(methodCall.receiver, Adjustment.BorrowReference(callee.methodSelfTy as TyReference))
}

var typeParameters = ctx.instantiateMethodOwnerSubstitution(callee, methodCall)
typeParameters = ctx.instantiateBounds(callee.element, callee.formalSelfTy, typeParameters)
typeParameters = ctx.instantiateBounds(resolved, callee.formalSelfTy, typeParameters)

val fnSubst = run {
val typeArguments = methodCall.typeArgumentList?.typeReferenceList.orEmpty().map { it.type }
if (typeArguments.isEmpty()) {
emptySubstitution
} else {
val parameters = callee.element.typeParameterList?.typeParameterList.orEmpty()
val parameters = resolved.typeParameterList?.typeParameterList.orEmpty()
.map { TyTypeParameter.named(it) }
parameters.zip(typeArguments).toMap().toTypeSubst()
}
}

unifySubst(fnSubst, typeParameters)

val methodType = (callee.element.type)
val methodType = (resolved.type)
.substitute(typeParameters)
.foldWith(associatedTypeNormalizer) as TyFunction
if (expected != null && !callee.element.isAsync) ctx.combineTypes(expected, methodType.retType)
if (expected != null && !resolved.isAsync) ctx.combineTypes(expected, methodType.retType)
// drop first element of paramTypes because it's `self` param
// and it doesn't have value in `methodCall.valueArgumentList.exprList`
inferArgumentTypes(methodType.paramTypes.drop(1), argExprs)
Expand Down Expand Up @@ -731,6 +742,10 @@ class RsTypeInferenceWalker(
}
}

fun inferConstArgumentTypes(constParameters: List<RsConstParameter>, constArguments: List<RsExpr>) {
inferArgumentTypes(constParameters.map { it.typeReference?.type ?: TyUnknown }, constArguments)
}

private fun inferFieldExprType(receiver: Ty, fieldLookup: RsFieldLookup): Ty {
if (fieldLookup.identifier?.text == "await" && fieldLookup.isEdition2018) {
return receiver.lookupFutureOutputTy(lookup)
Expand Down
Expand Up @@ -27,6 +27,21 @@ class RsTypeCheckInspectionTest : RsInspectionsTestBase(RsTypeCheckInspection::c
const A: [u8; <error>1u8</error>] = [0];
""")

fun `test typecheck in const argument`() = checkByText("""
#![feature(const_generics)]
struct S<const N: usize>;
trait T<const N: usize> {
fn foo<const N: usize>(&self) -> S<{ N }>;
}
impl T<<error>1u8</error>> for S<<error>1u8</error>> {
fn foo<const N: usize>(self) -> S<<error>1u8</error>> { self }
}
fn bar(x: S<<error>1u8</error>>) -> S<<error>1u8</error>> {
let s: S<<error>1u8</error>> = S::<<error>1u8</error>>;
s.foo::<<error>1u8</error>>()
}
""")

fun `test typecheck in enum variant discriminant`() = checkByText("""
enum Foo { BAR = <error>1u8</error> }
""")
Expand Down

0 comments on commit 2e8b363

Please sign in to comment.