Skip to content

Commit

Permalink
TY&RES: implement hygiene & infer types in macros
Browse files Browse the repository at this point in the history
  • Loading branch information
vlad20012 committed Jul 4, 2019
1 parent cadefbc commit c8d6d47
Show file tree
Hide file tree
Showing 17 changed files with 490 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/main/grammars/RustParser.bnf
Expand Up @@ -1404,6 +1404,7 @@ private BlockElement_recover ::= !('}' | Item_first | Expr_first | let | ';')

Stmt ::= LetDecl | EmptyStmt | never ';' {
implements = "org.rust.lang.core.macros.RsExpandedElement"
mixin = "org.rust.lang.core.psi.ext.RsStmtMixin"
}

ExprStmtOrLastExpr ::= StmtModeExpr (ExprStmtUpper | LastExprUpper) {
Expand Down
18 changes: 14 additions & 4 deletions src/main/kotlin/org/rust/lang/core/macros/RsExpandedElement.kt
Expand Up @@ -56,6 +56,9 @@ val RsExpandedElement.expandedFromRecursively: RsMacroCall?
return call
}

val RsExpandedElement.expandedFromSequence: Sequence<RsMacroCall>
get() = generateSequence(expandedFrom) { it.expandedFrom }

fun PsiElement.findMacroCallExpandedFrom(): RsMacroCall? {
val found = findMacroCallExpandedFromNonRecursive()
return found?.findMacroCallExpandedFrom() ?: found
Expand All @@ -72,13 +75,20 @@ val PsiElement.isExpandedFromMacro: Boolean
get() = findMacroCallExpandedFromNonRecursive() != null

/**
* If receiver element is inside a macro expansion, returns the leaf element inside the macro call
* from which the first token of this element is expanded. Always returns an element inside a root
* Same as [findElementExpandedFrom], but always returns an element inside a root
* macro call, i.e. outside of any expansion.
*/
fun PsiElement.findElementExpandedFromChecked(): PsiElement? {
return findElementExpandedFrom()?.takeIf { !it.isExpandedFromMacro }
}

/**
* If receiver element is inside a macro expansion, returns the leaf element inside the macro call
* from which the first token of this element is expanded
*/
fun PsiElement.findElementExpandedFrom(): PsiElement? {
val mappedElement = findElementExpandedFromNonRecursive() ?: return null
return mappedElement.findElementExpandedFrom() ?: mappedElement.takeIf { !it.isExpandedFromMacro }
return mappedElement.findElementExpandedFrom() ?: mappedElement
}

private fun PsiElement.findElementExpandedFromNonRecursive(): PsiElement? {
Expand Down Expand Up @@ -147,7 +157,7 @@ private fun Int.fromBodyRelativeOffset(call: RsMacroCall): Int? {
fun PsiElement.findNavigationTargetIfMacroExpansion(): PsiElement? {
/** @see RsNamedElementImpl.getTextOffset */
val element = (this as? RsNameIdentifierOwner)?.nameIdentifier ?: this
return element.findElementExpandedFrom() ?: findMacroCallExpandedFrom()?.path
return element.findElementExpandedFromChecked() ?: findMacroCallExpandedFrom()?.path
}

private val RS_EXPANSION_CONTEXT = Key.create<RsElement>("org.rust.lang.core.psi.CODE_FRAGMENT_FILE")
Expand Down
@@ -0,0 +1,23 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.lang.core.psi

import com.intellij.psi.PsiElement
import com.intellij.psi.PsiRecursiveElementWalkingVisitor
import org.rust.lang.core.psi.ext.expansion

abstract class RsWithMacroExpansionsRecursiveElementWalkingVisitor : PsiRecursiveElementWalkingVisitor() {
override fun visitElement(element: PsiElement) {
if (element is RsMacroCall && element.macroArgument != null) {
val expansion = element.expansion ?: return
for (expandedElement in expansion.elements) {
visitElement(expandedElement)
}
} else {
super.visitElement(element)
}
}
}
4 changes: 4 additions & 0 deletions src/main/kotlin/org/rust/lang/core/psi/ext/PsiElement.kt
Expand Up @@ -18,6 +18,7 @@ import com.intellij.psi.util.PsiUtilCore
import com.intellij.util.SmartList
import org.rust.lang.core.psi.RsFile
import org.rust.lang.core.stubs.RsFileStub
import org.rust.openapiext.findDescendantsWithMacrosOfAnyType

val PsiFileSystemItem.sourceRoot: VirtualFile?
get() = virtualFile.let { ProjectRootManager.getInstance(project).fileIndex.getSourceRootForFile(it) }
Expand Down Expand Up @@ -190,6 +191,9 @@ fun <T : PsiElement> getStubDescendantOfType(
}
}

inline fun <reified T : PsiElement> PsiElement.descendantsWithMacrosOfType(): Collection<T> =
findDescendantsWithMacrosOfAnyType(this, true, T::class.java)

/**
* Same as [PsiElement.getContainingFile], but return a "fake" file. See [org.rust.lang.core.macros.RsExpandedElement].
*/
Expand Down
45 changes: 45 additions & 0 deletions src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt
@@ -0,0 +1,45 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.lang.core.psi.ext

import org.rust.lang.core.macros.RsExpandedElement
import org.rust.lang.core.psi.RsBlock
import org.rust.lang.core.psi.RsExpr
import org.rust.lang.core.psi.RsMacroCall
import org.rust.lang.core.psi.RsStmt

/**
* Can contain [RsStmt]s and [RsExpr]s (which are equivalent to RsExprStmt(RsExpr))
*/
val RsBlock.expandedStmts: List<RsExpandedElement>
get() {
val stmts = mutableListOf<RsExpandedElement>()
processExpandedStmtsInternal { stmt ->
stmts.add(stmt)
false
}
return stmts
}

private val RsBlock.stmtsAndMacros: Sequence<RsElement>
get() {
val stub = greenStub
return if (stub != null) {
stub.childrenStubs.asSequence().map { it.psi }
} else {
childrenWithLeaves
}.filterIsInstance<RsElement>()
}

private fun RsBlock.processExpandedStmtsInternal(processor: (RsExpandedElement) -> Boolean): Boolean {
return stmtsAndMacros.any { it.processStmt(processor) }
}

private fun RsElement.processStmt(processor: (RsExpandedElement) -> Boolean) = when (this) {
is RsMacroCall -> processExpansionRecursively(processor)
is RsExpandedElement -> processor(this)
else -> false
}
15 changes: 15 additions & 0 deletions src/main/kotlin/org/rust/lang/core/psi/ext/RsStmt.kt
@@ -0,0 +1,15 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.lang.core.psi.ext

import com.intellij.lang.ASTNode
import com.intellij.psi.PsiElement
import org.rust.lang.core.macros.RsExpandedElement
import org.rust.lang.core.psi.RsStmt

abstract class RsStmtMixin(node: ASTNode) : RsElementImpl(node), RsStmt {
override fun getContext(): PsiElement? = RsExpandedElement.getContextImpl(this)
}
46 changes: 38 additions & 8 deletions src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt
Expand Up @@ -17,7 +17,10 @@ import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.intellij.psi.StubBasedPsiElement
import com.intellij.psi.stubs.StubElement
import com.intellij.psi.util.*
import com.intellij.psi.util.CachedValue
import com.intellij.psi.util.CachedValueProvider
import com.intellij.psi.util.CachedValuesManager
import com.intellij.psi.util.PsiTreeUtil
import org.rust.cargo.project.workspace.CargoWorkspace
import org.rust.cargo.project.workspace.PackageOrigin
import org.rust.cargo.util.AutoInjectedCrates.CORE
Expand Down Expand Up @@ -621,8 +624,9 @@ fun processLifetimeResolveVariants(lifetime: RsLifetime, processor: RsResolvePro
}

fun processLocalVariables(place: RsElement, processor: (RsPatBinding) -> Unit) {
val hygieneFilter = makeHygieneFilter(place)
walkUp(place, { it is RsItemElement }) { cameFrom, scope ->
processLexicalDeclarations(scope, cameFrom, VALUES, ItemProcessingMode.WITH_PRIVATE_IMPORTS) { v ->
processLexicalDeclarations(scope, cameFrom, VALUES, hygieneFilter, ItemProcessingMode.WITH_PRIVATE_IMPORTS) { v ->
val el = v.element
if (el is RsPatBinding) processor(el)
false
Expand Down Expand Up @@ -1174,14 +1178,15 @@ private fun processLexicalDeclarations(
scope: RsElement,
cameFrom: PsiElement,
ns: Set<Namespace>,
hygieneFilter: (RsPatBinding) -> Boolean,
ipm: ItemProcessingMode,
processor: RsResolveProcessor
): Boolean {
check(cameFrom.context == scope)

fun processPattern(pattern: RsPat, processor: RsResolveProcessor): Boolean {
val boundNames = PsiTreeUtil.findChildrenOfType(pattern, RsPatBinding::class.java)
.filter { it.reference.resolve() == null }
.filter { it.reference.resolve() == null && hygieneFilter(it) }
return processAll(boundNames, processor)
}

Expand Down Expand Up @@ -1265,10 +1270,14 @@ private fun processLexicalDeclarations(
}
}

for (stmt in scope.stmtList.asReversed()) {
val pat = (stmt as? RsLetDecl)?.pat ?: continue
if (PsiUtilCore.compareElementsByPosition(cameFrom, stmt) < 0) continue
if (stmt == cameFrom) continue
val letDecls = mutableListOf<RsLetDecl>()
for (stmt in scope.expandedStmts) {
if (cameFrom == stmt) break
if (stmt is RsLetDecl) letDecls.add(stmt)
}

for (let in letDecls.asReversed()) {
val pat = let.pat ?: continue
if (processPattern(pat, shadowingProcessor)) return true
}
}
Expand Down Expand Up @@ -1316,6 +1325,11 @@ fun processNestedScopesUpwards(
isCompletion: Boolean,
processor: RsResolveProcessor
): Boolean {
val hygieneFilter: (RsPatBinding) -> Boolean = if (scopeStart is RsPath && ns == VALUES) {
makeHygieneFilter(scopeStart)
} else {
{ true }
}
val prevScope = mutableSetOf<String>()
val stop = walkUp(scopeStart, { it is RsMod }) { cameFrom, scope ->
processWithShadowing(prevScope, processor) { shadowingProcessor ->
Expand All @@ -1324,7 +1338,7 @@ fun processNestedScopesUpwards(
isCompletion -> ItemProcessingMode.WITH_PRIVATE_IMPORTS_N_EXTERN_CRATES_COMPLETION
else -> ItemProcessingMode.WITH_PRIVATE_IMPORTS_N_EXTERN_CRATES
}
processLexicalDeclarations(scope, cameFrom, ns, ipm, shadowingProcessor)
processLexicalDeclarations(scope, cameFrom, ns, hygieneFilter, ipm, shadowingProcessor)
}
}
if (stop) return true
Expand All @@ -1341,6 +1355,21 @@ fun processNestedScopesUpwards(
return false
}

private fun makeHygieneFilter(anchor: PsiElement): (RsPatBinding) -> Boolean {
val hygienicScope = if (!anchor.isExpandedFromMacro) {
anchor.containingFile
} else {
val nameIdentifier = if (anchor is RsReferenceElement) anchor.referenceNameElement else anchor
(nameIdentifier.findElementExpandedFrom() ?: nameIdentifier).containingFile
}

return fun(element: RsPatBinding): Boolean {
val nameIdentifier = element.nameIdentifier ?: return false
val hygienicScope2 = (nameIdentifier.findElementExpandedFrom() ?: nameIdentifier).containingFile
return hygienicScope == hygienicScope2
}
}

inline fun processWithShadowing(
prevScope: MutableSet<String>,
crossinline processor: RsResolveProcessor,
Expand Down Expand Up @@ -1439,3 +1468,4 @@ object NameResolutionTestmarks {
}

private data class ImplicitStdlibCrate(val name: String, val crateRoot: RsFile)

Expand Up @@ -10,6 +10,7 @@ import org.rust.lang.core.macros.findExpansionElements
import org.rust.lang.core.psi.ext.RsElement
import org.rust.lang.core.psi.ext.RsReferenceElementBase
import org.rust.lang.core.psi.ext.ancestors
import org.rust.openapiext.Testmark

class RsMacroBodyReferenceDelegateImpl(
element: RsReferenceElementBase
Expand All @@ -19,6 +20,7 @@ class RsMacroBodyReferenceDelegateImpl(

private val delegates: List<RsReference>
get() {
Testmarks.touched.hit()
return element.findExpansionElements()?.mapNotNull { delegated ->
delegated.ancestors
.mapNotNull { it.reference }
Expand All @@ -32,4 +34,8 @@ class RsMacroBodyReferenceDelegateImpl(

override fun multiResolve(): List<RsElement> =
delegates.flatMap { it.multiResolve() }.distinct()

object Testmarks {
val touched = Testmark("touched")
}
}
Expand Up @@ -9,6 +9,7 @@ import com.intellij.openapi.progress.ProgressManager
import com.intellij.psi.PsiElement
import com.intellij.util.containers.isNullOrEmpty
import org.rust.lang.core.macros.MacroExpansion
import org.rust.lang.core.macros.expandedFromSequence
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.*
Expand Down Expand Up @@ -94,10 +95,25 @@ class RsTypeInferenceWalker(

private fun RsBlock.inferType(expected: Ty? = null, coerce: Boolean = false): Ty {
var isDiverging = false
for (stmt in stmtList) {
isDiverging = processStatement(stmt) || isDiverging
val expandedStmts = expandedStmts
val tailExpr = expandedStmts.lastOrNull()
?.let { it as? RsExpr }
?.takeIf { e ->
e.expandedFromSequence.all {
val bracesKind = it.bracesKind ?: return@all false
!bracesKind.needsSemicolon || it.semicolon == null
}
}
for (stmt in expandedStmts) {
val result = when (stmt) {
tailExpr -> false
is RsStmt -> processStatement(stmt)
is RsExpr -> stmt.inferType() == TyNever
else -> false
}
isDiverging = result || isDiverging
}
val type = (if (coerce) expr?.inferTypeCoercableTo(expected!!) else expr?.inferType(expected)) ?: TyUnit
val type = (if (coerce) tailExpr?.inferTypeCoercableTo(expected!!) else tailExpr?.inferType(expected)) ?: TyUnit
return if (isDiverging) TyNever else type
}

Expand Down Expand Up @@ -1134,7 +1150,7 @@ class RsTypeInferenceWalker(
name == "stringify" -> TyReference(TyStr, Mutability.IMMUTABLE)
name == "module_path" -> TyReference(TyStr, Mutability.IMMUTABLE)
name == "cfg" -> TyBool
else -> TyUnknown
else -> (macroCall.expansion as? MacroExpansion.Expr)?.expr?.inferType() ?: TyUnknown
}
}

Expand Down

0 comments on commit c8d6d47

Please sign in to comment.