Skip to content

Commit

Permalink
TY&RES: support const arguments that looks like type arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad20012 committed Jul 8, 2021
1 parent 48b62d4 commit 49e607b
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 69 deletions.
5 changes: 4 additions & 1 deletion src/main/kotlin/org/rust/ide/presentation/RsPsiRenderer.kt
Expand Up @@ -610,7 +610,10 @@ open class PsiSubstitutingPsiRenderer(
}
is RsConstParameter -> when (val s = subst.constSubst[resolved]) {
is RsPsiSubstitution.Value.Present -> {
appendConstExpr(sb, s.value)
when (s.value) {
is RsExpr -> appendConstExpr(sb, s.value)
is RsTypeReference -> appendTypeReference(sb, s.value)
}
true
}
else -> false
Expand Down
11 changes: 10 additions & 1 deletion src/main/kotlin/org/rust/lang/core/psi/ext/RsPath.kt
Expand Up @@ -69,7 +69,16 @@ val RsPath.qualifier: RsPath?
}

fun RsPath.allowedNamespaces(isCompletion: Boolean = false): Set<Namespace> = when (val parent = parent) {
is RsPath, is RsTypeReference, is RsTraitRef, is RsStructLiteral, is RsPatStruct -> TYPES
is RsPath, is RsTraitRef, is RsStructLiteral, is RsPatStruct -> TYPES
is RsTypeReference -> when (parent.parent) {
is RsTypeArgumentList -> when {
// type A = Foo<T>
// ~ `T` can be either type or const argument
typeArgumentList == null && valueParameterList == null -> TYPES_N_VALUES
else -> TYPES
}
else -> TYPES
}
is RsUseSpeck -> when {
// use foo::bar::{self, baz};
// ~~~~~~~~
Expand Down
159 changes: 97 additions & 62 deletions src/main/kotlin/org/rust/lang/core/resolve/ref/RsPathReferenceImpl.kt
Expand Up @@ -11,16 +11,18 @@ import com.intellij.util.containers.map2Array
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.*
import org.rust.lang.core.types.*
import org.rust.lang.core.types.RsPsiSubstitution.TypeValue
import org.rust.lang.core.types.RsPsiSubstitution.Value
import org.rust.lang.core.types.*
import org.rust.lang.core.types.infer.ResolvedPath
import org.rust.lang.core.types.infer.foldTyInferWith
import org.rust.lang.core.types.infer.substitute
import org.rust.lang.core.types.ty.*
import org.rust.lang.core.types.ty.TyInfer
import org.rust.lang.core.types.ty.TyUnknown
import org.rust.lang.utils.evaluation.PathExprResolver
import org.rust.stdext.buildMap
import org.rust.stdext.intersects
import org.rust.stdext.mapNotNullToSet

class RsPathReferenceImpl(
element: RsPath
Expand Down Expand Up @@ -84,18 +86,38 @@ class RsPathReferenceImpl(
it.inner.element
}

private fun advancedMultiResolve(): List<BoundElementWithVisibility<RsElement>> =
advancedMultiresolveUsingInferenceCache() ?: advancedCachedMultiResolve()
private fun advancedMultiResolve(): List<BoundElementWithVisibility<RsElement>> {
return when (val parent = element.parent) {
is RsPathExpr -> advancedMultiResolveUsingInferenceCache(parent)
is RsTypeReference -> when (val parentParent = parent.parent) {
is RsTypeArgumentList -> resolveTypeOrConstArg(parentParent, parent)
else -> advancedCachedMultiResolve()
}
else -> advancedCachedMultiResolve()
}
}

private fun advancedMultiresolveUsingInferenceCache(): List<BoundElementWithVisibility<RsElement>>? {
val path = element.parent as? RsPathExpr ?: return null
return path.inference?.getResolvedPath(path)?.map { result ->
private fun advancedMultiResolveUsingInferenceCache(pathExpr: RsPathExpr): List<BoundElementWithVisibility<RsElement>> {
val inference = pathExpr.inference ?: return emptyList()
return inference.getResolvedPath(pathExpr).map { result ->
val element = BoundElement(result.element, result.subst)
val isVisible = (result as? ResolvedPath.Item)?.isVisible ?: true
BoundElementWithVisibility(element, isVisible)
}
}

private fun resolveTypeOrConstArg(tal: RsTypeArgumentList, parent: RsTypeReference): List<BoundElementWithVisibility<RsElement>> {
val result = advancedCachedMultiResolve()
return when (result.size) {
0 -> emptyList()
1 -> result
else -> {
val withoutConstants = result.filter { it.inner.element !is RsConstant && it.inner.element !is RsConstParameter }
withoutConstants.ifEmpty { result }
}
}
}

private fun advancedCachedMultiResolve(): List<BoundElementWithVisibility<RsElement>> {
return RsResolveCache.getInstance(element.project)
.resolveWithCaching(element, ResolveCacheDependency.LOCAL_AND_RUST_STRUCTURE, Resolver)
Expand Down Expand Up @@ -185,50 +207,32 @@ fun <T : RsElement> instantiatePathGenerics(
fun pathPsiSubst(path: RsPath, resolved: RsGenericDeclaration): RsPsiSubstitution {
val args = pathTypeParameters(path)

val typeArguments = when (args) {
is RsPsiPathParameters.InAngles -> args.args.map { TypeValue.Present.InAngles(it) }
is RsPsiPathParameters.FnSugar -> listOf(TypeValue.Present.FnSugar(args.inputArgs))
null -> null
}

val assocTypes = run {
if (resolved is RsTraitItem) {
when (args) {
// Iterator<Item=T>
is RsPsiPathParameters.InAngles -> buildMap {
args.assoc.forEach { binding ->
// We can't just use `binding.reference.resolve()` here because
// resolving of an assoc type depends on a parent path resolve,
// so we coming back here and entering the infinite recursion
resolveAssocTypeBinding(resolved, binding)?.let { assoc ->
binding.typeReference?.let { put(assoc, it) }
}

}
}
// Fn() -> T
is RsPsiPathParameters.FnSugar -> buildMap {
if (args.outputArg != null) {
val outputParam = path.knownItems.FnOnce?.findAssociatedType("Output")
if (outputParam != null) {
put(outputParam, args.outputArg)
}
}
}
null -> emptyMap()
}
} else {
emptyMap<RsTypeAlias, RsTypeReference>()
}
}

val parent = path.parent

// Generic arguments are optional in expression context, e.g.
// `let a = Foo::<u8>::bar::<u16>();` can be written as `let a = Foo::bar();`
// if it is possible to infer `u8` and `u16` during type inference
val areOptionalArgs = parent is RsExpr || parent is RsPath && parent.parent is RsExpr

val regionParameters = resolved.lifetimeParameters
val regionArguments = (args as? RsPsiPathParameters.InAngles)?.lifetimeArgs
val regionSubst = regionParameters.withIndex().associate { (i, param) ->
val value = if (areOptionalArgs && regionArguments == null) {
Value.OptionalAbsent
} else if (regionArguments != null && i < regionArguments.size) {
Value.Present(regionArguments[i])
} else {
Value.RequiredAbsent
}
param to value
}

val typeArguments = when (args) {
is RsPsiPathParameters.InAngles -> args.typeOrConstArgs.filterIsInstance<RsTypeReference>().map { TypeValue.Present.InAngles(it) }
is RsPsiPathParameters.FnSugar -> listOf(TypeValue.Present.FnSugar(args.inputArgs))
null -> null
}

val typeSubst = resolved.typeParameters.withIndex().associate { (i, param) ->
val value = if (areOptionalArgs && typeArguments == null) {
// Args are optional and turbofish is not presend. E.g. `Vec::new()`
Expand Down Expand Up @@ -259,21 +263,11 @@ fun pathPsiSubst(path: RsPath, resolved: RsGenericDeclaration): RsPsiSubstitutio
param to value
}

val regionParameters = resolved.lifetimeParameters
val regionArguments = path.typeArgumentList?.lifetimeList
val regionSubst = regionParameters.withIndex().associate { (i, param) ->
val value = if (areOptionalArgs && regionArguments == null) {
Value.OptionalAbsent
} else if (regionArguments != null && i < regionArguments.size) {
Value.Present(regionArguments[i])
} else {
Value.RequiredAbsent
}
param to value
}
val usedTypeArguments = typeSubst.values.mapNotNullToSet { (it as? TypeValue.Present.InAngles)?.value }

val constParameters = resolved.constParameters
val constArguments = path.typeArgumentList?.exprList
val constArguments = (args as? RsPsiPathParameters.InAngles)?.typeOrConstArgs
?.let { list -> list.filter { it !is RsTypeReference || it !in usedTypeArguments && it is RsBaseType} }
val constSubst = constParameters.withIndex().associate { (i, param) ->
val value = if (areOptionalArgs && constArguments == null) {
Value.OptionalAbsent
Expand All @@ -285,13 +279,46 @@ fun pathPsiSubst(path: RsPath, resolved: RsGenericDeclaration): RsPsiSubstitutio
param to value
}

val assocTypes = run {
if (resolved is RsTraitItem) {
when (args) {
// Iterator<Item=T>
is RsPsiPathParameters.InAngles -> buildMap {
args.assoc.forEach { binding ->
// We can't just use `binding.reference.resolve()` here because
// resolving of an assoc type depends on a parent path resolve,
// so we coming back here and entering the infinite recursion
resolveAssocTypeBinding(resolved, binding)?.let { assoc ->
binding.typeReference?.let { put(assoc, it) }
}

}
}
// Fn() -> T
is RsPsiPathParameters.FnSugar -> buildMap {
if (args.outputArg != null) {
val outputParam = path.knownItems.FnOnce?.findAssociatedType("Output")
if (outputParam != null) {
put(outputParam, args.outputArg)
}
}
}
null -> emptyMap()
}
} else {
emptyMap<RsTypeAlias, RsTypeReference>()
}
}

return RsPsiSubstitution(typeSubst, regionSubst, constSubst, assocTypes)
}

private sealed class RsPsiPathParameters {
/** Foo<Bar, Baz, Item=i32> */
/** `Foo<'a, Bar, Baz, 2+2, Item=i32>` */
class InAngles(
val args: List<RsTypeReference>,
val lifetimeArgs: List<RsLifetime>,
/** [RsTypeReference] or [RsExpr] */
val typeOrConstArgs: List<RsElement>,
val assoc: List<RsAssocTypeBinding>
) : RsPsiPathParameters()

Expand All @@ -307,9 +334,17 @@ private fun pathTypeParameters(path: RsPath): RsPsiPathParameters? {
val fnSugar = path.valueParameterList
return when {
inAngles != null -> {
val params = inAngles.typeReferenceList
val assoc = inAngles.assocTypeBindingList
RsPsiPathParameters.InAngles(params, assoc)
val typeOrConstArgs = mutableListOf<RsElement>()
val lifetimeArgs = mutableListOf<RsLifetime>()
val assoc = mutableListOf<RsAssocTypeBinding>()
for (child in inAngles.stubChildrenOfType<RsElement>()) {
when (child) {
is RsTypeReference, is RsExpr -> typeOrConstArgs.add(child as RsElement)
is RsLifetime -> lifetimeArgs += child
is RsAssocTypeBinding -> assoc += child
}
}
RsPsiPathParameters.InAngles(lifetimeArgs, typeOrConstArgs, assoc)
}
fnSugar != null -> {
RsPsiPathParameters.FnSugar(
Expand Down
21 changes: 19 additions & 2 deletions src/main/kotlin/org/rust/lang/core/types/RsPsiSubstitution.kt
Expand Up @@ -6,6 +6,8 @@
package org.rust.lang.core.types

import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.RsElement
import org.rust.lang.core.psi.ext.isConst
import org.rust.lang.core.types.consts.CtConstParameter
import org.rust.lang.core.types.consts.CtUnknown
import org.rust.lang.core.types.infer.resolve
Expand All @@ -22,7 +24,7 @@ import org.rust.lang.utils.evaluation.evaluate
open class RsPsiSubstitution(
val typeSubst: Map<RsTypeParameter, TypeValue> = emptyMap(),
val regionSubst: Map<RsLifetimeParameter, Value<RsLifetime>> = emptyMap(),
val constSubst: Map<RsConstParameter, Value<RsExpr>> = emptyMap(),
val constSubst: Map<RsConstParameter, Value<RsElement>> = emptyMap(),
val assoc: Map<RsTypeAlias, RsTypeReference> = emptyMap(),
) {
sealed class TypeValue {
Expand Down Expand Up @@ -76,7 +78,22 @@ fun RsPsiSubstitution.toSubst(resolver: PathExprResolver? = PathExprResolver.def
RsPsiSubstitution.Value.RequiredAbsent -> CtUnknown
is RsPsiSubstitution.Value.Present -> {
val expectedTy = param.parameter.typeReference?.type ?: TyUnknown
psiValue.value.evaluate(expectedTy, resolver)
when (val value = psiValue.value) {
is RsExpr -> value.evaluate(expectedTy, resolver)
is RsBaseType -> when (val resolved = value.path?.reference?.resolve()) {
is RsConstParameter -> CtConstParameter(resolved)
is RsConstant -> when {
resolved.isConst -> {
// TODO check types
val type = resolved.typeReference?.type ?: TyUnknown
resolved.expr?.evaluate(type, resolver) ?: CtUnknown
}
else -> CtUnknown
}
else -> CtUnknown
}
else -> CtUnknown
}
}
}

Expand Down
15 changes: 12 additions & 3 deletions src/main/kotlin/org/rust/lang/core/types/infer/TypeInference.kt
Expand Up @@ -15,6 +15,7 @@ import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.*
import org.rust.lang.core.resolve.ref.MethodResolveVariant
import org.rust.lang.core.resolve.ref.pathPsiSubst
import org.rust.lang.core.resolve.ref.resolvePathRaw
import org.rust.lang.core.types.*
import org.rust.lang.core.types.consts.*
Expand Down Expand Up @@ -188,9 +189,17 @@ class RsInferenceContext(
else -> null
}
val declaration = path?.let { resolvePathRaw(it, lookup) }?.singleOrNull()?.element as? RsGenericDeclaration
val constParameters = declaration?.constParameters.orEmpty()
val constArguments = path?.constArguments.orEmpty()
RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments)
if (path != null && declaration != null) {
val constParameters = mutableListOf<RsConstParameter>()
val constArguments = mutableListOf<RsExpr>()
for ((param, value) in pathPsiSubst(path, declaration).constSubst) {
if (value is RsPsiSubstitution.Value.Present && value.value is RsExpr) {
constParameters += param
constArguments += value.value
}
}
RsTypeInferenceWalker(this, TyUnknown).inferConstArgumentTypes(constParameters, constArguments)
}
}
else -> {
val (retTy, expr) = when (element) {
Expand Down
Expand Up @@ -354,4 +354,14 @@ class RsTypeCheckInspectionTest : RsInspectionsTestBase(RsTypeCheckInspection::c
foo(<error>&s</error>);
}
""")

// Issue https://github.com/intellij-rust/intellij-rust/issues/7420
fun `test correctly match const arguments with const parameters`() = checkErrors("""
struct Foo<const N: u8, const S: bool>(u32);
impl<const N: u8> From<u32> for Foo<N, true> {
// ^^^^ the error should not appear here
fn from(val: u32) -> Self { Foo(val) }
}
""")
}

0 comments on commit 49e607b

Please sign in to comment.