Skip to content

Commit

Permalink
TY: Check type of const generic argument
Browse files Browse the repository at this point in the history
  • Loading branch information
mchernyavsky committed Dec 20, 2019
1 parent a8df8aa commit 0eee38c
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 32 deletions.
4 changes: 3 additions & 1 deletion src/main/grammars/RustParser.bnf
Expand Up @@ -240,6 +240,7 @@ ValuePathGenericArgsNoTypeQual ::= <<typeQuals 'OFF' <<pathMode 'VALUE' PathImpl

// Path semantically constrained to resolve to a trait
TraitRef ::= TypePathGenericArgsNoTypeQual {
implements = "org.rust.lang.core.psi.ext.RsInferenceContextOwner"
extends = "org.rust.lang.core.psi.ext.RsStubbedElementImpl<?>"
stubClass = "org.rust.lang.core.stubs.RsPlaceholderStub"
elementTypeFactory = "org.rust.lang.core.stubs.StubImplementationsKt.factory"
Expand Down Expand Up @@ -891,7 +892,8 @@ ForInType ::= ForLifetimes (FnPointerType | TraitRef) {
}

BaseType ::= TrivialBaseTypeInner | TypePathGenericArgs {
implements = "org.rust.lang.core.psi.ext.RsTypeElement"
implements = [ "org.rust.lang.core.psi.ext.RsTypeElement"
"org.rust.lang.core.psi.ext.RsInferenceContextOwner" ]
extends = "org.rust.lang.core.psi.ext.RsStubbedElementImpl<?>"
stubClass = "org.rust.lang.core.stubs.RsBaseTypeStub"
elementTypeFactory = "org.rust.lang.core.stubs.StubImplementationsKt.factory"
Expand Down
Expand Up @@ -15,6 +15,8 @@ abstract class RsDiagnosticBasedInspection : RsLocalInspectionTool() {
override fun visitFunction(o: RsFunction) = collectDiagnostics(holder, o)
override fun visitConstant(o: RsConstant) = collectDiagnostics(holder, o)
override fun visitArrayType(o: RsArrayType) = collectDiagnostics(holder, o)
override fun visitBaseType(o: RsBaseType) = collectDiagnostics(holder, o)
override fun visitTraitRef(o: RsTraitRef) = collectDiagnostics(holder, o)
override fun visitVariantDiscriminant(o: RsVariantDiscriminant) = collectDiagnostics(holder, o)
}

Expand Down
Expand Up @@ -22,5 +22,7 @@ val RsInferenceContextOwner.body: RsElement?
is RsFunction -> block
is RsVariantDiscriminant -> expr
is RsExpressionCodeFragment -> expr
is RsBaseType -> path?.typeArgumentList
is RsTraitRef -> path.typeArgumentList
else -> null
}
74 changes: 44 additions & 30 deletions src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt
Expand Up @@ -15,6 +15,7 @@ import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.*
import org.rust.lang.core.resolve.ref.MethodResolveVariant
import org.rust.lang.core.resolve.ref.resolvePathRaw
import org.rust.lang.core.types.*
import org.rust.lang.core.types.regions.Region
import org.rust.lang.core.types.ty.*
Expand Down Expand Up @@ -171,40 +172,53 @@ class RsInferenceContext(
}

fun infer(element: RsInferenceContextOwner): RsInferenceResult {
if (element is RsFunction) {
val fctx = RsTypeInferenceWalker(this, element.returnType)
fctx.extractParameterBindings(element)
element.block?.let { fctx.inferFnBody(it) }
} else if (element is RsReplCodeFragment) {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
when (element) {
is RsFunction -> {
val fctx = RsTypeInferenceWalker(this, element.returnType)
fctx.extractParameterBindings(element)
element.block?.let { fctx.inferFnBody(it) }
}

val walker = RsTypeInferenceWalker(this, TyUnknown)
walker.inferReplCodeFragment(element)
} else {
val (retTy, expr) = when (element) {
is RsConstant -> element.typeReference?.type to element.expr
is RsArrayType -> TyInteger.USize to element.expr
is RsVariantDiscriminant -> {
val enum = element.ancestorStrict<RsEnumItem>()
enum?.reprType to element.expr
is RsReplCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
is RsExpressionCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
null to element.expr
RsTypeInferenceWalker(this, TyUnknown).inferReplCodeFragment(element)
}
is RsBaseType, is RsTraitRef -> {
val path = when (element) {
is RsBaseType -> element.path
is RsTraitRef -> element.path
else -> null
}
else -> error("Type inference is not implemented for PSI element of type " +
"`${element.javaClass}` that implement `RsInferenceContextOwner`")
val declaration = path?.let { resolvePathRaw(it, lookup) }?.singleOrNull()?.element as? RsGenericDeclaration
val constParameters = declaration?.constParameters.orEmpty()
val constArguments = path?.constArguments.orEmpty()
RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments)
}
if (expr != null) {
RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr)
else -> {
val (retTy, expr) = when (element) {
is RsConstant -> element.typeReference?.type to element.expr
is RsArrayType -> TyInteger.USize to element.expr
is RsVariantDiscriminant -> {
val enum = element.contextStrict<RsEnumItem>()
enum?.reprType to element.expr
}
is RsExpressionCodeFragment -> {
element.context.inference?.let {
patTypes.putAll(it.patTypes)
patFieldTypes.putAll(it.patFieldTypes)
exprTypes.putAll(it.exprTypes)
}
null to element.expr
}
else -> error("Type inference is not implemented for PSI element of type " +
"`${element.javaClass}` that implement `RsInferenceContextOwner`")
}
if (expr != null) {
RsTypeInferenceWalker(this, retTy ?: TyUnknown).inferLambdaBody(expr)
}
}
}

Expand Down
Expand Up @@ -392,7 +392,13 @@ class RsTypeInferenceWalker(
scopeEntry: ScopeEntry,
pathExpr: RsPathExpr
): Ty {
val subst = instantiatePathGenerics(pathExpr.path, BoundElement(element, scopeEntry.subst)).subst
val path = pathExpr.path
val subst = instantiatePathGenerics(path, BoundElement(element, scopeEntry.subst)).subst

if (element is RsGenericDeclaration) {
inferConstArgumentTypes(element.constParameters, path.constArguments)
}

val type = when (element) {
is RsPatBinding -> ctx.getBindingType(element)
is RsTypeDeclarationElement -> element.declaredType
Expand Down Expand Up @@ -579,6 +585,8 @@ class RsTypeInferenceWalker(
return methodType.retType
}

inferConstArgumentTypes(callee.element.constParameters, methodCall.constArguments)

ctx.addDerefAdjustments(methodCall.receiver, callee.derefChain)
if (callee.borrow != null) {
ctx.addAdjustment(methodCall.receiver, Adjustment.BorrowReference(callee.methodSelfTy as TyReference))
Expand Down Expand Up @@ -731,6 +739,10 @@ class RsTypeInferenceWalker(
}
}

fun inferConstArgumentTypes(constParameters: List<RsConstParameter>, constArguments: List<RsExpr>) {
inferArgumentTypes(constParameters.map { it.typeReference?.type ?: TyUnknown }, constArguments)
}

private fun inferFieldExprType(receiver: Ty, fieldLookup: RsFieldLookup): Ty {
if (fieldLookup.identifier?.text == "await" && fieldLookup.isEdition2018) {
return receiver.lookupFutureOutputTy(lookup)
Expand Down
Expand Up @@ -27,6 +27,21 @@ class RsTypeCheckInspectionTest : RsInspectionsTestBase(RsTypeCheckInspection::c
const A: [u8; <error>1u8</error>] = [0];
""")

fun `test typecheck in const argument`() = checkByText("""
#![feature(const_generics)]
struct S<const N: usize>;
trait T<const N: usize> {
fn foo<const N: usize>(&self) -> S<{ N }>;
}
impl T<<error>1u8</error>> for S<<error>1u8</error>> {
fn foo<const N: usize>(self) -> S<<error>1u8</error>> { self }
}
fn bar(x: S<<error>1u8</error>>) -> S<<error>1u8</error>> {
let s: S<<error>1u8</error>> = S::<<error>1u8</error>>;
s.foo::<<error>1u8</error>>()
}
""")

fun `test typecheck in enum variant discriminant`() = checkByText("""
enum Foo { BAR = <error>1u8</error> }
""")
Expand Down

0 comments on commit 0eee38c

Please sign in to comment.