Skip to content

Commit

Permalink
Merge #4242
Browse files Browse the repository at this point in the history
4242: TY: Initial type inference for const generics r=vlad20012 a=mchernyavsky

Relates to #3985.
Depends on #4677 and #4782.

Co-authored-by: Mikhail Chernyavsky <mikhail.chernyavsky@jetbrains.com>
Co-authored-by: Mikhail Chernyavsky <Mikhail.Chernyavsky@jetbrains.com>
  • Loading branch information
3 people committed Mar 10, 2020
2 parents 84c2690 + 7c168be commit f0f3387
Show file tree
Hide file tree
Showing 50 changed files with 1,871 additions and 527 deletions.
85 changes: 60 additions & 25 deletions src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt
Expand Up @@ -31,12 +31,12 @@ import org.rust.lang.core.resolve.knownItems
import org.rust.lang.core.resolve.namespaces
import org.rust.lang.core.resolve.ref.deepResolve
import org.rust.lang.core.types.*
import org.rust.lang.core.types.consts.asLong
import org.rust.lang.core.types.ty.*
import org.rust.lang.utils.RsDiagnostic
import org.rust.lang.utils.RsErrorCode
import org.rust.lang.utils.addToHolder
import org.rust.lang.utils.evaluation.ExprValue
import org.rust.lang.utils.evaluation.RsConstExprEvaluator
import org.rust.lang.utils.evaluation.evaluate

class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
override fun isForceHighlightParents(file: PsiFile): Boolean = file is RsFile
Expand Down Expand Up @@ -128,9 +128,15 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
val declarationFieldsAmount = declaration.fields.size
val bodyFieldsAmount = patTupleStruct.patList.size
if (bodyFieldsAmount < declarationFieldsAmount && patTupleStruct.patRest == null) {
RsDiagnostic.MissingFieldsInTuplePattern(patTupleStruct, declaration, declarationFieldsAmount, bodyFieldsAmount).addToHolder(holder)
RsDiagnostic.MissingFieldsInTuplePattern(
patTupleStruct,
declaration,
declarationFieldsAmount,
bodyFieldsAmount
).addToHolder(holder)
} else if (bodyFieldsAmount > declarationFieldsAmount) {
RsDiagnostic.ExtraFieldInTupleStructPattern(patTupleStruct, bodyFieldsAmount, declarationFieldsAmount).addToHolder(holder)
RsDiagnostic.ExtraFieldInTupleStructPattern(patTupleStruct, bodyFieldsAmount, declarationFieldsAmount)
.addToHolder(holder)
}
}

Expand Down Expand Up @@ -205,12 +211,7 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
val indexToVariantMap = hashMapOf<Long, VariantInfo>()
for (variant in o.enumVariantList) {
val expr = variant.variantDiscriminant?.expr
val int = if (expr != null) {
val result = RsConstExprEvaluator.evaluate(expr, reprType) as? ExprValue.Integer
result?.value ?: return
} else {
null
}
val int = if (expr != null) expr.evaluate(reprType).asLong() ?: return else null
val idx = int ?: discrCounter
discrCounter = idx + 1

Expand Down Expand Up @@ -290,19 +291,25 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
val error = when {
element is RsNamedFieldDecl -> {
val structName = element.ancestorStrict<RsStructItem>()?.crateRelativePath?.removePrefix("::") ?: ""
RsDiagnostic.StructFieldAccessError(ref, ref.referenceName, structName,
MakePublicFix.createIfCompatible(element, element.name, withinOneCrate))
RsDiagnostic.StructFieldAccessError(
ref, ref.referenceName, structName,
MakePublicFix.createIfCompatible(element, element.name, withinOneCrate)
)
}
ref is RsMethodCall -> RsDiagnostic.AccessError(ref.identifier, RsErrorCode.E0624, "Method",
MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate))
ref is RsMethodCall -> RsDiagnostic.AccessError(
ref.identifier, RsErrorCode.E0624, "Method",
MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate)
)
else -> {
val itemType = when (element) {
is RsItemElement -> element.itemKindName.capitalize()
else -> "Item"
}

RsDiagnostic.AccessError(ref, RsErrorCode.E0603, itemType,
MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate))
RsDiagnostic.AccessError(
ref, RsErrorCode.E0603, itemType,
MakePublicFix.createIfCompatible(element, ref.referenceName, withinOneCrate)
)
}
}
error.addToHolder(holder)
Expand All @@ -327,7 +334,8 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
if (type.underscore == null) return
val owner = type.owner.parent
if ((owner is RsValueParameter && owner.parent.parent is RsFunction)
|| (owner is RsRetType && owner.parent is RsFunction) || owner is RsConstant) {
|| (owner is RsRetType && owner.parent is RsFunction) || owner is RsConstant
) {
RsDiagnostic.TypePlaceholderForbiddenError(type).addToHolder(holder)
}
}
Expand All @@ -348,7 +356,10 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
private fun checkPath(holder: RsAnnotationHolder, path: RsPath) {
val qualifier = path.path
if ((qualifier == null || isValidSelfSuperPrefix(qualifier)) && !isValidSelfSuperPrefix(path)) {
holder.createErrorAnnotation(path.referenceNameElement, "Invalid path: self and super are allowed only at the beginning")
holder.createErrorAnnotation(
path.referenceNameElement,
"Invalid path: self and super are allowed only at the beginning"
)
return
}

Expand Down Expand Up @@ -539,7 +550,12 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
}

// E0120: Drop can be only implemented by structs and enums
private fun checkImplDropForNonAdtError(holder: RsAnnotationHolder, impl: RsImplItem, traitRef: RsTraitRef, trait: RsTraitItem) {
private fun checkImplDropForNonAdtError(
holder: RsAnnotationHolder,
impl: RsImplItem,
traitRef: RsTraitRef,
trait: RsTraitItem
) {
if (trait != trait.knownItems.Drop) return

if (impl.typeReference?.type is TyAdt?) return
Expand All @@ -559,7 +575,12 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
}

// E0184: Cannot implement both Copy and Drop
private fun checkImplBothCopyAndDrop(holder: RsAnnotationHolder, self: Ty, element: PsiElement, trait: RsTraitItem) {
private fun checkImplBothCopyAndDrop(
holder: RsAnnotationHolder,
self: Ty,
element: PsiElement,
trait: RsTraitItem
) {
val oppositeTrait = when (trait) {
trait.knownItems.Drop -> trait.knownItems.Copy
trait.knownItems.Copy -> trait.knownItems.Drop
Expand Down Expand Up @@ -685,7 +706,11 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
RsDiagnostic.InvalidStartAttrError.InvalidParam(params[0].typeReference ?: params[0], 0)
.addToHolder(holder)
}
if (params[1].typeReference?.type != TyPointer(TyPointer(TyInteger.U8, Mutability.IMMUTABLE), Mutability.IMMUTABLE)) {
if (params[1].typeReference?.type != TyPointer(
TyPointer(TyInteger.U8, Mutability.IMMUTABLE),
Mutability.IMMUTABLE
)
) {
RsDiagnostic.InvalidStartAttrError.InvalidParam(params[1].typeReference ?: params[1], 1)
.addToHolder(holder)
}
Expand Down Expand Up @@ -771,8 +796,10 @@ class RsErrorAnnotator : AnnotatorBase(), HighlightRangeExtension {
val dotdoteq = range.dotdoteq ?: range.dotdotdot ?: return
if (dotdoteq == range.dotdotdot) {
// rustc doesn't have an error code for this ("error: unexpected token: `...`")
holder.createErrorAnnotation(dotdoteq,
"`...` syntax is deprecated. Use `..` for an exclusive range or `..=` for an inclusive range")
holder.createErrorAnnotation(
dotdoteq,
"`...` syntax is deprecated. Use `..` for an exclusive range or `..=` for an inclusive range"
)
return
}
val expr = range.exprList.singleOrNull() ?: return
Expand Down Expand Up @@ -844,7 +871,12 @@ private fun RsExpr?.isComparisonBinaryExpr(): Boolean {
return op is ComparisonOp || op is EqualityOp
}

private fun checkDuplicates(holder: RsAnnotationHolder, element: RsNameIdentifierOwner, scope: PsiElement = element.parent, recursively: Boolean = false) {
private fun checkDuplicates(
holder: RsAnnotationHolder,
element: RsNameIdentifierOwner,
scope: PsiElement = element.parent,
recursively: Boolean = false
) {
if (element.isCfgUnknown) return
val owner = if (scope is RsMembers) scope.parent else scope
val duplicates = holder.currentAnnotationSession.duplicatesByNamespace(scope, recursively)
Expand Down Expand Up @@ -897,7 +929,10 @@ private fun PsiElement.nameOrImportedName(): String? =
else -> null
}

private fun AnnotationSession.duplicatesByNamespace(owner: PsiElement, recursively: Boolean): Map<Namespace, Set<PsiElement>> {
private fun AnnotationSession.duplicatesByNamespace(
owner: PsiElement,
recursively: Boolean
): Map<Namespace, Set<PsiElement>> {
if (owner.parent is RsFnPointerType) return emptyMap()

fun PsiElement.namespaced(): Sequence<Pair<Namespace, PsiElement>> =
Expand Down
Expand Up @@ -13,6 +13,9 @@ import org.rust.lang.core.psi.RsTypeAlias
import org.rust.lang.core.psi.RsTypeParameter
import org.rust.lang.core.psi.ext.typeParameters
import org.rust.lang.core.types.BoundElement
import org.rust.lang.core.types.Kind
import org.rust.lang.core.types.consts.CtConstParameter
import org.rust.lang.core.types.consts.CtValue
import org.rust.lang.core.types.ty.*
import org.rust.lang.core.types.type

Expand All @@ -22,19 +25,22 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p
listOf(text(": "), hint(type, 1)).join()
)

private fun hint(type: Ty, level: Int): InlayPresentation = when (type) {
is TyTuple -> tupleTypeHint(type, level)
is TyAdt -> adtTypeHint(type, level)
is TyFunction -> functionTypeHint(type, level)
is TyReference -> referenceTypeHint(type, level)
is TyPointer -> pointerTypeHint(type, level)
is TyProjection -> projectionTypeHint(type, level)
is TyTypeParameter -> typeParameterTypeHint(type)
is TyArray -> arrayTypeHint(type, level)
is TySlice -> sliceTypeHint(type, level)
is TyTraitObject -> traitObjectTypeHint(type, level)
is TyAnon -> anonTypeHint(type, level)
else -> text(type.shortPresentableText)
private fun hint(kind: Kind, level: Int): InlayPresentation = when (kind) {
is TyTuple -> tupleTypeHint(kind, level)
is TyAdt -> adtTypeHint(kind, level)
is TyFunction -> functionTypeHint(kind, level)
is TyReference -> referenceTypeHint(kind, level)
is TyPointer -> pointerTypeHint(kind, level)
is TyProjection -> projectionTypeHint(kind, level)
is TyTypeParameter -> typeParameterTypeHint(kind)
is TyArray -> arrayTypeHint(kind, level)
is TySlice -> sliceTypeHint(kind, level)
is TyTraitObject -> traitObjectTypeHint(kind, level)
is TyAnon -> anonTypeHint(kind, level)
is CtConstParameter -> constParameterTypeHint(kind)
is CtValue -> text(kind.expr.toString())
is Ty -> text(kind.shortPresentableText)
else -> text(null)
}

private fun functionTypeHint(type: TyFunction, level: Int): InlayPresentation {
Expand Down Expand Up @@ -88,22 +94,25 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p
val typeArguments = alias?.typeParameters?.map { (aliasedBy.subst[it] ?: TyUnknown) to it }
?: type.typeArguments.zip(type.item.typeParameters)

val userVisibleTypeArguments = mutableListOf<Ty>()
val userVisibleKindArguments = mutableListOf<Kind>()
for ((argument, parameter) in typeArguments) {
if (!showObviousTypes && isDefaultTypeParameter(argument, parameter)) {
// don't show default types
continue
}
userVisibleTypeArguments.add(argument)
userVisibleKindArguments.add(argument)
}
for (argument in type.constArguments) {
userVisibleKindArguments.add(argument)
}

if (userVisibleTypeArguments.isNotEmpty()) {
if (userVisibleKindArguments.isNotEmpty()) {
val collapsible = factory.collapsible(
prefix = text("<"),
collapsed = text(PLACEHOLDER),
expanded = { parametersHint(userVisibleTypeArguments, level + 1) },
expanded = { parametersHint(userVisibleKindArguments, level + 1) },
suffix = text(">"),
startWithPlaceholder = checkSize(level, userVisibleTypeArguments.size)
startWithPlaceholder = checkSize(level, userVisibleKindArguments.size)
)
return listOf(typeNamePresentation, collapsible).join()
}
Expand Down Expand Up @@ -147,6 +156,9 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p
return text(parameter.name)
}

private fun constParameterTypeHint(const: CtConstParameter): InlayPresentation =
factory.psiSingleReference(text(const.parameter.name)) { const.parameter }

private fun arrayTypeHint(type: TyArray, level: Int): InlayPresentation =
factory.collapsible(
prefix = text("["),
Expand Down Expand Up @@ -187,8 +199,8 @@ class RsTypeHintsPresentationFactory(private val factory: PresentationFactory, p
startWithPlaceholder = checkSize(level, type.traits.size)
)

private fun parametersHint(types: List<Ty>, level: Int): InlayPresentation =
types.map { hint(it, level) }.join(", ")
private fun parametersHint(kinds: List<Kind>, level: Int): InlayPresentation =
kinds.map { hint(it, level) }.join(", ")

private fun traitItemTypeHint(
trait: BoundElement<RsTraitItem>,
Expand Down
Expand Up @@ -7,12 +7,13 @@ package org.rust.ide.inspections.checkMatch

import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.types.consts.CtValue
import org.rust.lang.core.types.ty.Ty
import org.rust.lang.core.types.ty.TyAdt
import org.rust.lang.core.types.ty.TyUnknown
import org.rust.lang.core.types.type
import org.rust.lang.utils.evaluation.ExprValue
import org.rust.lang.utils.evaluation.RsConstExprEvaluator
import org.rust.lang.utils.evaluation.ConstExpr.Value
import org.rust.lang.utils.evaluation.evaluate

class CheckMatchException(message: String) : Exception(message)

Expand Down Expand Up @@ -44,7 +45,8 @@ val Matrix.firstColumnType: Ty
fun List<RsMatchArm>.calculateMatrix(): Matrix =
flatMap { arm -> arm.patList.map { listOf(it.lower) } }

private val RsExpr.value: ExprValue? get() = RsConstExprEvaluator.evaluate(this)
private val RsExpr.value: Value<*>?
get() = (evaluate() as? CtValue)?.expr

// lower_pattern_unadjusted
private val RsPat.kind: PatternKind
Expand Down
29 changes: 15 additions & 14 deletions src/main/kotlin/org/rust/ide/inspections/checkMatch/Constructor.kt
Expand Up @@ -13,21 +13,21 @@ import org.rust.lang.core.psi.ext.fieldTypes
import org.rust.lang.core.psi.ext.size
import org.rust.lang.core.psi.ext.variants
import org.rust.lang.core.types.ty.*
import org.rust.lang.utils.evaluation.ExprValue
import org.rust.lang.utils.evaluation.ConstExpr.Value

sealed class Constructor {

/** The constructor of all patterns that don't vary by constructor, e.g. struct patterns and fixed-length arrays */
object Single : Constructor() {
override fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean = true
override fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean = true
}

/** Enum variants */
data class Variant(val variant: RsEnumVariant) : Constructor()

/** Literal values */
data class ConstantValue(val value: ExprValue) : Constructor() {
override fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean =
data class ConstantValue(val value: Value<*>) : Constructor() {
override fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean =
if (included) {
value >= from && value <= to
} else {
Expand All @@ -36,8 +36,9 @@ sealed class Constructor {
}

/** Ranges of literal values (`2..=5` and `2..5`) */
data class ConstantRange(val start: ExprValue, val end: ExprValue, val includeEnd: Boolean = false) : Constructor() {
override fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean =
data class ConstantRange(val start: Value<*>, val end: Value<*>, val includeEnd: Boolean = false) :
Constructor() {
override fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean =
if (includeEnd) {
((end < to) || (included && to == end)) && (start >= from)
} else {
Expand Down Expand Up @@ -68,7 +69,7 @@ sealed class Constructor {
else -> 0
}

open fun coveredByRange(from: ExprValue, to: ExprValue, included: Boolean): Boolean = false
open fun coveredByRange(from: Value<*>, to: Value<*>, included: Boolean): Boolean = false

fun subTypes(type: Ty): List<Ty> = when (type) {
is TyTuple -> type.types
Expand All @@ -94,7 +95,7 @@ sealed class Constructor {
companion object {
fun allConstructors(ty: Ty): List<Constructor> =
when {
ty is TyBool -> listOf(true, false).map { ConstantValue(ExprValue.Bool(it)) }
ty is TyBool -> listOf(true, false).map { ConstantValue(Value.Bool(it)) }

ty is TyAdt && ty.item is RsEnumItem -> ty.item.variants.map { Variant(it) }

Expand All @@ -107,13 +108,13 @@ sealed class Constructor {
}
}

private operator fun ExprValue.compareTo(other: ExprValue): Int {
private operator fun Value<*>.compareTo(other: Value<*>): Int {
return when {
this is ExprValue.Bool && other is ExprValue.Bool -> value.compareTo(other.value)
this is ExprValue.Integer && other is ExprValue.Integer -> value.compareTo(other.value)
this is ExprValue.Float && other is ExprValue.Float -> value.compareTo(other.value)
this is ExprValue.Str && other is ExprValue.Str -> value.compareTo(other.value)
this is ExprValue.Char && other is ExprValue.Char -> value.compareTo(other.value)
this is Value.Bool && other is Value.Bool -> value.compareTo(other.value)
this is Value.Integer && other is Value.Integer -> value.compareTo(other.value)
this is Value.Float && other is Value.Float -> value.compareTo(other.value)
this is Value.Str && other is Value.Str -> value.compareTo(other.value)
this is Value.Char && other is Value.Char -> value.compareTo(other.value)
else -> throw CheckMatchException("Comparison of incompatible types: $javaClass and ${other.javaClass}")
}
}

0 comments on commit f0f3387

Please sign in to comment.