Skip to content

Commit

Permalink
Merge #3996 #4094
Browse files Browse the repository at this point in the history
3996: INT & ANN: Implement intention and fix for struct fields expansion instead of `..` r=vlad20012 a=shevtsiv

<!--
Hello and thank you for the pull request!

We don't have any strict rules about pull requests, but you might check
https://github.com/intellij-rust/intellij-rust/blob/master/CONTRIBUTING.md
for some hints!

Note that we need an electronic CLA for contributions:
https://github.com/intellij-rust/intellij-rust/blob/master/CONTRIBUTING.md#cla

After you sign the CLA, please add your name to
https://github.com/intellij-rust/intellij-rust/blob/master/CONTRIBUTORS.txt

:)
-->
In this pull request I intend to do the following:
1) Create E0023 and E0027 error classes and implement analysis for them.
2) Create fix for those errors.
3) Create intention for struct fields expansion.
4) Create a module with shared code for using it in the fix and intention.

Once these steps are complete I will squash all the commits and #3928 should be closed after merge.

4094: TY&RES: implement hygiene & infer types in macros r=undin a=vlad20012

Now we take into account any custom macros in type inference

Co-authored-by: shevtsiv <rostykshevtsiv@gmail.com>
Co-authored-by: vlad20012 <beskvlad@gmail.com>
  • Loading branch information
3 people committed Aug 19, 2019
3 parents b74af9c + b0eac1c + 3246664 commit 472af45
Show file tree
Hide file tree
Showing 30 changed files with 1,649 additions and 41 deletions.
1 change: 1 addition & 0 deletions src/main/grammars/RustParser.bnf
Expand Up @@ -1449,6 +1449,7 @@ private BlockElement_recover ::= !('}' | Item_first | Expr_first | let | ';')

Stmt ::= LetDecl | EmptyStmt | never ';' {
implements = "org.rust.lang.core.macros.RsExpandedElement"
mixin = "org.rust.lang.core.psi.ext.RsStmtMixin"
}

ExprStmtOrLastExpr ::= StmtModeExpr (ExprStmtUpper | LastExprUpper) {
Expand Down
38 changes: 38 additions & 0 deletions src/main/kotlin/org/rust/ide/annotator/RsErrorAnnotator.kt
Expand Up @@ -85,11 +85,49 @@ class RsErrorAnnotator : RsAnnotatorBase(), HighlightRangeExtension {
override fun visitSelfParameter(o: RsSelfParameter) = checkParamAttrs(holder, o)
override fun visitValueParameter(o: RsValueParameter) = checkParamAttrs(holder, o)
override fun visitVariadic(o: RsVariadic) = checkParamAttrs(holder, o)
override fun visitPatStruct(o: RsPatStruct) = checkRsPatStruct(holder, o)
override fun visitPatTupleStruct(o: RsPatTupleStruct) = checkRsPatTupleStruct(holder, o)
}

element.accept(visitor)
}

private fun checkRsPatStruct(holder: AnnotationHolder, patStruct: RsPatStruct) {
val declaration = patStruct.path.reference.deepResolve() as? RsFieldsOwner ?: return
val declarationFieldNames = declaration.fields.map { it.name }
val bodyFields = patStruct.patFieldList
val extraFields = bodyFields.filter { it.kind.fieldName !in declarationFieldNames }
val bodyFieldNames = bodyFields.map { it.kind.fieldName }
val missingFields = declaration.fields.filter { it.name !in bodyFieldNames && !it.queryAttributes.hasCfgAttr() }
extraFields.forEach {
RsDiagnostic.ExtraFieldInStructPattern(it).addToHolder(holder)
}
if (missingFields.isNotEmpty() && patStruct.dotdot == null) {
if (declaration.elementType == RsElementTypes.ENUM_VARIANT) {
RsDiagnostic.MissingFieldsInEnumVariantPattern(patStruct, declaration.text).addToHolder(holder)
} else {
RsDiagnostic.MissingFieldsInStructPattern(patStruct, declaration.text).addToHolder(holder)
}
}
}

private fun checkRsPatTupleStruct(holder: AnnotationHolder, patTupleStruct: RsPatTupleStruct) {
val declaration = patTupleStruct.path.reference.deepResolve() as? RsFieldsOwner ?: return
val bodyFields = patTupleStruct.childrenOfType<RsPatIdent>()
if (bodyFields.size < declaration.fields.size && patTupleStruct.dotdot == null) {
if (declaration.elementType == RsElementTypes.ENUM_VARIANT) {
RsDiagnostic.MissingFieldsInEnumVariantTuplePattern(patTupleStruct, declaration.text).addToHolder(holder)
} else {
RsDiagnostic.MissingFieldsInTupleStructPattern(patTupleStruct, declaration.text).addToHolder(holder)
}
} else if (bodyFields.size > declaration.fields.size) {
RsDiagnostic.ExtraFieldInTupleStructPattern(
patTupleStruct,
bodyFields.size ,
declaration.fields.size
).addToHolder(holder)
}
}

private fun checkTraitType(holder: AnnotationHolder, traitType: RsTraitType) {
if (!traitType.isImpl) return
Expand Down
@@ -0,0 +1,34 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ide.annotator.fixes

import com.intellij.codeInspection.LocalQuickFixAndIntentionActionOnPsiElement
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import org.rust.ide.utils.expandStructFields
import org.rust.ide.utils.expandTupleStructFields
import org.rust.lang.core.psi.RsPatStruct
import org.rust.lang.core.psi.RsPatTupleStruct
import org.rust.lang.core.psi.RsPsiFactory

class AddStructFieldsPatFix(
element: PsiElement
) : LocalQuickFixAndIntentionActionOnPsiElement(element) {
override fun getText() = "Add missing fields"

override fun getFamilyName() = text

override fun invoke(project: Project, file: PsiFile, editor: Editor?, pat: PsiElement, endElement: PsiElement) {
val factory = RsPsiFactory(project)
if (pat is RsPatStruct) {
expandStructFields(factory, pat)
} else if (pat is RsPatTupleStruct) {
expandTupleStructFields(factory, editor, pat)
}
}
}
@@ -0,0 +1,43 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ide.intentions

import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import org.rust.ide.utils.expandStructFields
import org.rust.ide.utils.expandTupleStructFields
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.elementType

class AddStructFieldsPatIntention : RsElementBaseIntentionAction<AddStructFieldsPatIntention.Context>() {
override fun getText() = "Replace .. with actual fields"

override fun getFamilyName() = text

data class Context(
val structBody: RsPat
)

override fun findApplicableContext(project: Project, editor: Editor, element: PsiElement): Context? {
return if (element.elementType == RsElementTypes.DOTDOT
&& (element.context is RsPatStruct || element.context is RsPatTupleStruct)) {
Context(element.context as RsPat)
} else {
null
}
}

override fun invoke(project: Project, editor: Editor, ctx: Context) {
val factory = RsPsiFactory(project)
val structBody = ctx.structBody
if (structBody is RsPatStruct) {
expandStructFields(factory, structBody)
} else if (structBody is RsPatTupleStruct) {
expandTupleStructFields(factory, editor, structBody)
}
}
}
101 changes: 101 additions & 0 deletions src/main/kotlin/org/rust/ide/utils/StructFieldsExpander.kt
@@ -0,0 +1,101 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ide.utils

import com.intellij.openapi.editor.Editor
import com.intellij.psi.PsiElement
import com.intellij.psi.impl.source.tree.LeafPsiElement
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.ref.deepResolve
import org.rust.openapiext.buildAndRunTemplate
import org.rust.openapiext.createSmartPointer

fun expandStructFields(factory: RsPsiFactory, patStruct: RsPatStruct) {
val declaration = patStruct.path.reference.deepResolve() as? RsFieldsOwner ?: return
val hasTrailingComma = patStruct.rbrace.getPrevNonCommentSibling()?.elementType == RsElementTypes.COMMA
patStruct.dotdot?.delete()
val existingFields = patStruct.patFieldList
val bodyFieldNames = existingFields.map { it.kind.fieldName }.toSet()
val missingFields = declaration.fields
.filter { it.name !in bodyFieldNames && !it.queryAttributes.hasCfgAttr() }
.map { factory.createPatField(it.name!!) }

if (existingFields.isEmpty()) {
addFieldsToPat(factory, patStruct, missingFields, hasTrailingComma)
return
}

val fieldPositions = declaration.fields.withIndex().associate { it.value.name!! to it.index }
var insertedFieldsAmount = 0
for (missingField in missingFields) {
val missingFieldPosition = fieldPositions[missingField.kind.fieldName]!!
for (existingField in existingFields) {
val existingFieldPosition = fieldPositions[existingField.kind.fieldName] ?: continue
if (missingFieldPosition < existingFieldPosition) {
patStruct.addBefore(missingField, existingField)
patStruct.addAfter(factory.createComma(), existingField.getPrevNonCommentSibling())
insertedFieldsAmount++
break
}
}
}
addFieldsToPat(factory, patStruct, missingFields.drop(insertedFieldsAmount), hasTrailingComma)
}

fun expandTupleStructFields(factory: RsPsiFactory, editor: Editor?, patTuple: RsPatTupleStruct) {
val declaration = patTuple.path.reference.deepResolve() as? RsFieldsOwner ?: return
val hasTrailingComma = patTuple.rparen.getPrevNonCommentSibling()?.elementType == RsElementTypes.COMMA
val bodyFields = patTuple.childrenOfType<RsPatIdent>()
val missingFieldsAmount = declaration.fields.size - bodyFields.size
addFieldsToPat(factory, patTuple, createTupleStructMissingFields(factory, missingFieldsAmount), hasTrailingComma)
patTuple.dotdot?.delete()
editor?.buildAndRunTemplate(patTuple, patTuple.childrenOfType<RsPatBinding>().map { it.createSmartPointer() })
}

private fun createTupleStructMissingFields(factory: RsPsiFactory, amount: Int): List<RsPatBinding> {
val missingFields = ArrayList<RsPatBinding>(amount)
for (i in 0 until amount) {
missingFields.add(factory.createPatBinding("_$i"))
}
return missingFields
}

private fun addFieldsToPat(factory: RsPsiFactory, pat: RsPat, fields: List<PsiElement>, hasTrailingComma: Boolean) {
var anchor = determineOrCreateAnchor(factory, pat)
for (missingField in fields) {
pat.addAfter(missingField, anchor)
// Do not insert comma if we are in the middle of pattern
// since it will cause double comma in patterns with a trailing comma.
if (fields.last() == missingField) {
if (anchor.nextSibling?.getNextNonCommentSibling()?.elementType != RsElementTypes.DOTDOT) {
pat.addAfter(factory.createComma(), anchor.nextSibling)
}
} else {
pat.addAfter(factory.createComma(), anchor.nextSibling)
}
anchor = anchor.nextSibling.nextSibling as LeafPsiElement
}
if (!hasTrailingComma) {
anchor.delete()
}
}

private fun determineOrCreateAnchor(factory: RsPsiFactory, pat: RsPat): PsiElement {
val dots = pat.childrenOfType<LeafPsiElement>().firstOrNull { it.elementType == RsElementTypes.DOTDOT }
if (dots != null) {
// Picking dots prev sibling as anchor allows as to fill the pattern starting from dots position
// instead of filling pattern starting from the end.
return dots.getPrevNonCommentSibling()!! as LeafPsiElement
}
val lastElementInBody = pat.lastChild.getPrevNonCommentSibling()!!
return if (lastElementInBody !is LeafPsiElement) {
pat.addAfter(factory.createComma(), lastElementInBody)
lastElementInBody.nextSibling
} else {
lastElementInBody
}
}
66 changes: 60 additions & 6 deletions src/main/kotlin/org/rust/lang/core/macros/RsExpandedElement.kt
Expand Up @@ -56,6 +56,9 @@ val RsExpandedElement.expandedFromRecursively: RsMacroCall?
return call
}

val RsExpandedElement.expandedFromSequence: Sequence<RsMacroCall>
get() = generateSequence(expandedFrom) { it.expandedFrom }

fun PsiElement.findMacroCallExpandedFrom(): RsMacroCall? {
val found = findMacroCallExpandedFromNonRecursive()
return found?.findMacroCallExpandedFrom() ?: found
Expand All @@ -72,13 +75,64 @@ val PsiElement.isExpandedFromMacro: Boolean
get() = findMacroCallExpandedFromNonRecursive() != null

/**
* If receiver element is inside a macro expansion, returns the leaf element inside the macro call
* from which the first token of this element is expanded. Always returns an element inside a root
* macro call, i.e. outside of any expansion.
* If [this] is inside a **macro expansion**, returns a leaf element inside a macro call from which
* the first token of this element is expanded. Returns null if [this] element is not inside a
* macro expansion or source element is not a part of a macro call (i.e. is a part of a macro
* definition)
*
* If [strict] is `true`, always returns an element inside a root macro call, i.e. outside of any
* expansion, or null otherwise.
*
* # Examples
*
* ```rust
* macro_rules! foo {
* ($i:ident) => { struct $i; }
* }
* // Source code // Expansion
* foo!(Bar); // struct Bar;
* // \____________________/
* //^ For this element returns `Bar` element in the macro call.
* // It works the same regardless the [strict] value
* ```
*
* ```rust
* macro_rules! foo {
* ($i:item) => { $i }
* }
* macro_rules! bar {
* ($i:ident) => { struct $i; }
* }
* // Source code // Expansion step 1 // Expansion step 2
* foo! { bar!(Baz); } // bar!(Baz); // struct Baz;
* // \_______________________________________/
* //^ For this element returns `Baz` element in the macro call.
* // It works the same regardless the [strict] value
* ```
*
* ```rust
* macro_rules! foo {
* () => { bar!(Baz); }
* }
* macro_rules! bar {
* ($i:ident) => { struct $i; }
* }
* // Source code // Expansion step 1 // Expansion step 2
* foo!(); // bar!(Baz); // struct Baz;
* // \______________________/
* //^ For this element returns `Baz` element in the intermediate
* // macro call ONLY if the [strict] value is `false`.
* // Returns null otherwise
* ```
*/
fun PsiElement.findElementExpandedFrom(): PsiElement? {
fun PsiElement.findElementExpandedFrom(strict: Boolean = true): PsiElement? {
val expandedFrom = findElementExpandedFromUnchecked()
return if (strict) expandedFrom?.takeIf { !it.isExpandedFromMacro } else expandedFrom
}

private fun PsiElement.findElementExpandedFromUnchecked(): PsiElement? {
val mappedElement = findElementExpandedFromNonRecursive() ?: return null
return mappedElement.findElementExpandedFrom() ?: mappedElement.takeIf { !it.isExpandedFromMacro }
return mappedElement.findElementExpandedFromUnchecked() ?: mappedElement
}

private fun PsiElement.findElementExpandedFromNonRecursive(): PsiElement? {
Expand All @@ -96,7 +150,7 @@ private fun mapOffsetFromExpansionToCallBody(call: RsMacroCall, offset: Int): In
}

/**
* If [this] element is inside a macro call body and this macro is successfully expanded, returns
* If [this] element is inside a **macro call** body and this macro is successfully expanded, returns
* a leaf element inside the macro expansion that is expanded from [this] element. Returns a
* list of elements because an element inside a macro call body can be placed in a macro expansion
* multiple times. Returns null if [this] element is not inside a macro call body, or the macro
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/org/rust/lang/core/psi/RsCodeFragment.kt
Expand Up @@ -77,7 +77,7 @@ abstract class RsCodeFragment(

override fun getForcedResolveScope(): GlobalSearchScope? = forcedResolveScope

override fun getContext(): PsiElement? = context
override fun getContext(): PsiElement = context

final override fun getViewProvider(): SingleRootFileViewProvider = viewProvider

Expand Down
8 changes: 8 additions & 0 deletions src/main/kotlin/org/rust/lang/core/psi/RsPsiFactory.kt
Expand Up @@ -387,6 +387,14 @@ class RsPsiFactory(
?.firstChild as RsPatBinding?
?: error("Failed to create pat element")

fun createPatField(name: String): RsPatField =
createFromText("""
struct Foo { bar: i32 }
fn baz(foo: Foo) {
let Foo { $name } = foo;
}
""") ?: error("Failed to create pat field")

fun createPatStruct(struct: RsStructItem): RsPatStruct {
val structName = struct.name ?: error("Failed to create pat struct")
val pad = if (struct.namedFields.isEmpty()) "" else " "
Expand Down
@@ -0,0 +1,23 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.lang.core.psi

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiRecursiveElementWalkingVisitor
import org.rust.lang.core.psi.ext.expansion

abstract class RsWithMacroExpansionsRecursiveElementWalkingVisitor : PsiRecursiveElementWalkingVisitor() {
override fun visitElement(element: PsiElement) {
if (element is RsMacroCall && element.macroArgument != null) {
val expansion = element.expansion ?: return
for (expandedElement in expansion.elements) {
visitElement(expandedElement)
}
} else {
super.visitElement(element)
}
}
}
4 changes: 4 additions & 0 deletions src/main/kotlin/org/rust/lang/core/psi/ext/PsiElement.kt
Expand Up @@ -18,6 +18,7 @@ import com.intellij.psi.util.PsiUtilCore
import com.intellij.util.SmartList
import org.rust.lang.core.psi.RsFile
import org.rust.lang.core.stubs.RsFileStub
import org.rust.openapiext.findDescendantsWithMacrosOfAnyType

val PsiFileSystemItem.sourceRoot: VirtualFile?
get() = virtualFile.let { ProjectRootManager.getInstance(project).fileIndex.getSourceRootForFile(it) }
Expand Down Expand Up @@ -193,6 +194,9 @@ fun <T : PsiElement> getStubDescendantOfType(
}
}

inline fun <reified T : PsiElement> PsiElement.descendantsWithMacrosOfType(): Collection<T> =
findDescendantsWithMacrosOfAnyType(this, true, T::class.java)

/**
* Same as [PsiElement.getContainingFile], but return a "fake" file. See [org.rust.lang.core.macros.RsExpandedElement].
*/
Expand Down

0 comments on commit 472af45

Please sign in to comment.