From c8d6d47317ccd6db0039f9f4fb8bec9afcad320d Mon Sep 17 00:00:00 2001 From: vlad20012 Date: Mon, 24 Jun 2019 12:52:14 +0300 Subject: [PATCH] TY&RES: implement hygiene & infer types in macros --- src/main/grammars/RustParser.bnf | 1 + .../lang/core/macros/RsExpandedElement.kt | 18 ++- ...xpansionsRecursiveElementWalkingVisitor.kt | 23 ++++ .../org/rust/lang/core/psi/ext/PsiElement.kt | 4 + .../org/rust/lang/core/psi/ext/RsBlock.kt | 45 +++++++ .../org/rust/lang/core/psi/ext/RsStmt.kt | 15 +++ .../rust/lang/core/resolve/NameResolution.kt | 46 +++++-- .../ref/RsMacroBodyReferenceDelegateImpl.kt | 6 + .../core/types/infer/TypeInferenceWalker.kt | 24 +++- src/main/kotlin/org/rust/openapiext/psi.kt | 60 +++++++++ src/test/kotlin/org/rust/RsTestBase.kt | 46 +++++-- .../RsMacroCallReferenceDelegationTest.kt | 22 ++-- .../RsMacroExpansionRangeMappingTest.kt | 4 +- .../resolve/RsMacroExpansionResolveTest.kt | 86 +++++++++++++ .../lang/core/resolve/RsResolveTestBase.kt | 8 +- .../core/type/RsMacroTypeInferenceTest.kt | 116 ++++++++++++++++++ .../lang/core/type/RsTypificationTestBase.kt | 8 +- 17 files changed, 490 insertions(+), 42 deletions(-) create mode 100644 src/main/kotlin/org/rust/lang/core/psi/RsWithMacroExpansionsRecursiveElementWalkingVisitor.kt create mode 100644 src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt create mode 100644 src/main/kotlin/org/rust/lang/core/psi/ext/RsStmt.kt create mode 100644 src/test/kotlin/org/rust/lang/core/type/RsMacroTypeInferenceTest.kt diff --git a/src/main/grammars/RustParser.bnf b/src/main/grammars/RustParser.bnf index acb99ac08d9..2422ea9d122 100644 --- a/src/main/grammars/RustParser.bnf +++ b/src/main/grammars/RustParser.bnf @@ -1404,6 +1404,7 @@ private BlockElement_recover ::= !('}' | Item_first | Expr_first | let | ';') Stmt ::= LetDecl | EmptyStmt | never ';' { implements = "org.rust.lang.core.macros.RsExpandedElement" + mixin = "org.rust.lang.core.psi.ext.RsStmtMixin" } ExprStmtOrLastExpr ::= StmtModeExpr (ExprStmtUpper | LastExprUpper) { diff --git a/src/main/kotlin/org/rust/lang/core/macros/RsExpandedElement.kt b/src/main/kotlin/org/rust/lang/core/macros/RsExpandedElement.kt index 46cdb2995e7..a1e94998eca 100644 --- a/src/main/kotlin/org/rust/lang/core/macros/RsExpandedElement.kt +++ b/src/main/kotlin/org/rust/lang/core/macros/RsExpandedElement.kt @@ -56,6 +56,9 @@ val RsExpandedElement.expandedFromRecursively: RsMacroCall? return call } +val RsExpandedElement.expandedFromSequence: Sequence + get() = generateSequence(expandedFrom) { it.expandedFrom } + fun PsiElement.findMacroCallExpandedFrom(): RsMacroCall? { val found = findMacroCallExpandedFromNonRecursive() return found?.findMacroCallExpandedFrom() ?: found @@ -72,13 +75,20 @@ val PsiElement.isExpandedFromMacro: Boolean get() = findMacroCallExpandedFromNonRecursive() != null /** - * If receiver element is inside a macro expansion, returns the leaf element inside the macro call - * from which the first token of this element is expanded. Always returns an element inside a root + * Same as [findElementExpandedFrom], but always returns an element inside a root * macro call, i.e. outside of any expansion. */ +fun PsiElement.findElementExpandedFromChecked(): PsiElement? { + return findElementExpandedFrom()?.takeIf { !it.isExpandedFromMacro } +} + +/** + * If receiver element is inside a macro expansion, returns the leaf element inside the macro call + * from which the first token of this element is expanded + */ fun PsiElement.findElementExpandedFrom(): PsiElement? { val mappedElement = findElementExpandedFromNonRecursive() ?: return null - return mappedElement.findElementExpandedFrom() ?: mappedElement.takeIf { !it.isExpandedFromMacro } + return mappedElement.findElementExpandedFrom() ?: mappedElement } private fun PsiElement.findElementExpandedFromNonRecursive(): PsiElement? { @@ -147,7 +157,7 @@ private fun Int.fromBodyRelativeOffset(call: RsMacroCall): Int? { fun PsiElement.findNavigationTargetIfMacroExpansion(): PsiElement? { /** @see RsNamedElementImpl.getTextOffset */ val element = (this as? RsNameIdentifierOwner)?.nameIdentifier ?: this - return element.findElementExpandedFrom() ?: findMacroCallExpandedFrom()?.path + return element.findElementExpandedFromChecked() ?: findMacroCallExpandedFrom()?.path } private val RS_EXPANSION_CONTEXT = Key.create("org.rust.lang.core.psi.CODE_FRAGMENT_FILE") diff --git a/src/main/kotlin/org/rust/lang/core/psi/RsWithMacroExpansionsRecursiveElementWalkingVisitor.kt b/src/main/kotlin/org/rust/lang/core/psi/RsWithMacroExpansionsRecursiveElementWalkingVisitor.kt new file mode 100644 index 00000000000..42338c9a79a --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/psi/RsWithMacroExpansionsRecursiveElementWalkingVisitor.kt @@ -0,0 +1,23 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.psi + +import com.intellij.psi.PsiElement +import com.intellij.psi.PsiRecursiveElementWalkingVisitor +import org.rust.lang.core.psi.ext.expansion + +abstract class RsWithMacroExpansionsRecursiveElementWalkingVisitor : PsiRecursiveElementWalkingVisitor() { + override fun visitElement(element: PsiElement) { + if (element is RsMacroCall && element.macroArgument != null) { + val expansion = element.expansion ?: return + for (expandedElement in expansion.elements) { + visitElement(expandedElement) + } + } else { + super.visitElement(element) + } + } +} diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/PsiElement.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/PsiElement.kt index 153516cdafa..a9ce7e66728 100644 --- a/src/main/kotlin/org/rust/lang/core/psi/ext/PsiElement.kt +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/PsiElement.kt @@ -18,6 +18,7 @@ import com.intellij.psi.util.PsiUtilCore import com.intellij.util.SmartList import org.rust.lang.core.psi.RsFile import org.rust.lang.core.stubs.RsFileStub +import org.rust.openapiext.findDescendantsWithMacrosOfAnyType val PsiFileSystemItem.sourceRoot: VirtualFile? get() = virtualFile.let { ProjectRootManager.getInstance(project).fileIndex.getSourceRootForFile(it) } @@ -190,6 +191,9 @@ fun getStubDescendantOfType( } } +inline fun PsiElement.descendantsWithMacrosOfType(): Collection = + findDescendantsWithMacrosOfAnyType(this, true, T::class.java) + /** * Same as [PsiElement.getContainingFile], but return a "fake" file. See [org.rust.lang.core.macros.RsExpandedElement]. */ diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt new file mode 100644 index 00000000000..74df9551526 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsBlock.kt @@ -0,0 +1,45 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.psi.ext + +import org.rust.lang.core.macros.RsExpandedElement +import org.rust.lang.core.psi.RsBlock +import org.rust.lang.core.psi.RsExpr +import org.rust.lang.core.psi.RsMacroCall +import org.rust.lang.core.psi.RsStmt + +/** + * Can contain [RsStmt]s and [RsExpr]s (which are equivalent to RsExprStmt(RsExpr)) + */ +val RsBlock.expandedStmts: List + get() { + val stmts = mutableListOf() + processExpandedStmtsInternal { stmt -> + stmts.add(stmt) + false + } + return stmts + } + +private val RsBlock.stmtsAndMacros: Sequence + get() { + val stub = greenStub + return if (stub != null) { + stub.childrenStubs.asSequence().map { it.psi } + } else { + childrenWithLeaves + }.filterIsInstance() + } + +private fun RsBlock.processExpandedStmtsInternal(processor: (RsExpandedElement) -> Boolean): Boolean { + return stmtsAndMacros.any { it.processStmt(processor) } +} + +private fun RsElement.processStmt(processor: (RsExpandedElement) -> Boolean) = when (this) { + is RsMacroCall -> processExpansionRecursively(processor) + is RsExpandedElement -> processor(this) + else -> false +} diff --git a/src/main/kotlin/org/rust/lang/core/psi/ext/RsStmt.kt b/src/main/kotlin/org/rust/lang/core/psi/ext/RsStmt.kt new file mode 100644 index 00000000000..b603825f734 --- /dev/null +++ b/src/main/kotlin/org/rust/lang/core/psi/ext/RsStmt.kt @@ -0,0 +1,15 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.psi.ext + +import com.intellij.lang.ASTNode +import com.intellij.psi.PsiElement +import org.rust.lang.core.macros.RsExpandedElement +import org.rust.lang.core.psi.RsStmt + +abstract class RsStmtMixin(node: ASTNode) : RsElementImpl(node), RsStmt { + override fun getContext(): PsiElement? = RsExpandedElement.getContextImpl(this) +} diff --git a/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt b/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt index 1dad8915e25..0a3958e6a68 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/NameResolution.kt @@ -17,7 +17,10 @@ import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile import com.intellij.psi.StubBasedPsiElement import com.intellij.psi.stubs.StubElement -import com.intellij.psi.util.* +import com.intellij.psi.util.CachedValue +import com.intellij.psi.util.CachedValueProvider +import com.intellij.psi.util.CachedValuesManager +import com.intellij.psi.util.PsiTreeUtil import org.rust.cargo.project.workspace.CargoWorkspace import org.rust.cargo.project.workspace.PackageOrigin import org.rust.cargo.util.AutoInjectedCrates.CORE @@ -621,8 +624,9 @@ fun processLifetimeResolveVariants(lifetime: RsLifetime, processor: RsResolvePro } fun processLocalVariables(place: RsElement, processor: (RsPatBinding) -> Unit) { + val hygieneFilter = makeHygieneFilter(place) walkUp(place, { it is RsItemElement }) { cameFrom, scope -> - processLexicalDeclarations(scope, cameFrom, VALUES, ItemProcessingMode.WITH_PRIVATE_IMPORTS) { v -> + processLexicalDeclarations(scope, cameFrom, VALUES, hygieneFilter, ItemProcessingMode.WITH_PRIVATE_IMPORTS) { v -> val el = v.element if (el is RsPatBinding) processor(el) false @@ -1174,6 +1178,7 @@ private fun processLexicalDeclarations( scope: RsElement, cameFrom: PsiElement, ns: Set, + hygieneFilter: (RsPatBinding) -> Boolean, ipm: ItemProcessingMode, processor: RsResolveProcessor ): Boolean { @@ -1181,7 +1186,7 @@ private fun processLexicalDeclarations( fun processPattern(pattern: RsPat, processor: RsResolveProcessor): Boolean { val boundNames = PsiTreeUtil.findChildrenOfType(pattern, RsPatBinding::class.java) - .filter { it.reference.resolve() == null } + .filter { it.reference.resolve() == null && hygieneFilter(it) } return processAll(boundNames, processor) } @@ -1265,10 +1270,14 @@ private fun processLexicalDeclarations( } } - for (stmt in scope.stmtList.asReversed()) { - val pat = (stmt as? RsLetDecl)?.pat ?: continue - if (PsiUtilCore.compareElementsByPosition(cameFrom, stmt) < 0) continue - if (stmt == cameFrom) continue + val letDecls = mutableListOf() + for (stmt in scope.expandedStmts) { + if (cameFrom == stmt) break + if (stmt is RsLetDecl) letDecls.add(stmt) + } + + for (let in letDecls.asReversed()) { + val pat = let.pat ?: continue if (processPattern(pat, shadowingProcessor)) return true } } @@ -1316,6 +1325,11 @@ fun processNestedScopesUpwards( isCompletion: Boolean, processor: RsResolveProcessor ): Boolean { + val hygieneFilter: (RsPatBinding) -> Boolean = if (scopeStart is RsPath && ns == VALUES) { + makeHygieneFilter(scopeStart) + } else { + { true } + } val prevScope = mutableSetOf() val stop = walkUp(scopeStart, { it is RsMod }) { cameFrom, scope -> processWithShadowing(prevScope, processor) { shadowingProcessor -> @@ -1324,7 +1338,7 @@ fun processNestedScopesUpwards( isCompletion -> ItemProcessingMode.WITH_PRIVATE_IMPORTS_N_EXTERN_CRATES_COMPLETION else -> ItemProcessingMode.WITH_PRIVATE_IMPORTS_N_EXTERN_CRATES } - processLexicalDeclarations(scope, cameFrom, ns, ipm, shadowingProcessor) + processLexicalDeclarations(scope, cameFrom, ns, hygieneFilter, ipm, shadowingProcessor) } } if (stop) return true @@ -1341,6 +1355,21 @@ fun processNestedScopesUpwards( return false } +private fun makeHygieneFilter(anchor: PsiElement): (RsPatBinding) -> Boolean { + val hygienicScope = if (!anchor.isExpandedFromMacro) { + anchor.containingFile + } else { + val nameIdentifier = if (anchor is RsReferenceElement) anchor.referenceNameElement else anchor + (nameIdentifier.findElementExpandedFrom() ?: nameIdentifier).containingFile + } + + return fun(element: RsPatBinding): Boolean { + val nameIdentifier = element.nameIdentifier ?: return false + val hygienicScope2 = (nameIdentifier.findElementExpandedFrom() ?: nameIdentifier).containingFile + return hygienicScope == hygienicScope2 + } +} + inline fun processWithShadowing( prevScope: MutableSet, crossinline processor: RsResolveProcessor, @@ -1439,3 +1468,4 @@ object NameResolutionTestmarks { } private data class ImplicitStdlibCrate(val name: String, val crateRoot: RsFile) + diff --git a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsMacroBodyReferenceDelegateImpl.kt b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsMacroBodyReferenceDelegateImpl.kt index 22afa973235..d4837918462 100644 --- a/src/main/kotlin/org/rust/lang/core/resolve/ref/RsMacroBodyReferenceDelegateImpl.kt +++ b/src/main/kotlin/org/rust/lang/core/resolve/ref/RsMacroBodyReferenceDelegateImpl.kt @@ -10,6 +10,7 @@ import org.rust.lang.core.macros.findExpansionElements import org.rust.lang.core.psi.ext.RsElement import org.rust.lang.core.psi.ext.RsReferenceElementBase import org.rust.lang.core.psi.ext.ancestors +import org.rust.openapiext.Testmark class RsMacroBodyReferenceDelegateImpl( element: RsReferenceElementBase @@ -19,6 +20,7 @@ class RsMacroBodyReferenceDelegateImpl( private val delegates: List get() { + Testmarks.touched.hit() return element.findExpansionElements()?.mapNotNull { delegated -> delegated.ancestors .mapNotNull { it.reference } @@ -32,4 +34,8 @@ class RsMacroBodyReferenceDelegateImpl( override fun multiResolve(): List = delegates.flatMap { it.multiResolve() }.distinct() + + object Testmarks { + val touched = Testmark("touched") + } } diff --git a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt index 21b2baa08b1..0292aa78bad 100644 --- a/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt +++ b/src/main/kotlin/org/rust/lang/core/types/infer/TypeInferenceWalker.kt @@ -9,6 +9,7 @@ import com.intellij.openapi.progress.ProgressManager import com.intellij.psi.PsiElement import com.intellij.util.containers.isNullOrEmpty import org.rust.lang.core.macros.MacroExpansion +import org.rust.lang.core.macros.expandedFromSequence import org.rust.lang.core.psi.* import org.rust.lang.core.psi.ext.* import org.rust.lang.core.resolve.* @@ -94,10 +95,25 @@ class RsTypeInferenceWalker( private fun RsBlock.inferType(expected: Ty? = null, coerce: Boolean = false): Ty { var isDiverging = false - for (stmt in stmtList) { - isDiverging = processStatement(stmt) || isDiverging + val expandedStmts = expandedStmts + val tailExpr = expandedStmts.lastOrNull() + ?.let { it as? RsExpr } + ?.takeIf { e -> + e.expandedFromSequence.all { + val bracesKind = it.bracesKind ?: return@all false + !bracesKind.needsSemicolon || it.semicolon == null + } + } + for (stmt in expandedStmts) { + val result = when (stmt) { + tailExpr -> false + is RsStmt -> processStatement(stmt) + is RsExpr -> stmt.inferType() == TyNever + else -> false + } + isDiverging = result || isDiverging } - val type = (if (coerce) expr?.inferTypeCoercableTo(expected!!) else expr?.inferType(expected)) ?: TyUnit + val type = (if (coerce) tailExpr?.inferTypeCoercableTo(expected!!) else tailExpr?.inferType(expected)) ?: TyUnit return if (isDiverging) TyNever else type } @@ -1134,7 +1150,7 @@ class RsTypeInferenceWalker( name == "stringify" -> TyReference(TyStr, Mutability.IMMUTABLE) name == "module_path" -> TyReference(TyStr, Mutability.IMMUTABLE) name == "cfg" -> TyBool - else -> TyUnknown + else -> (macroCall.expansion as? MacroExpansion.Expr)?.expr?.inferType() ?: TyUnknown } } diff --git a/src/main/kotlin/org/rust/openapiext/psi.kt b/src/main/kotlin/org/rust/openapiext/psi.kt index 3d725bb3bd3..f0f02afb206 100644 --- a/src/main/kotlin/org/rust/openapiext/psi.kt +++ b/src/main/kotlin/org/rust/openapiext/psi.kt @@ -5,8 +5,15 @@ package org.rust.openapiext +import com.intellij.psi.PsiCompiledElement import com.intellij.psi.PsiElement import com.intellij.psi.impl.source.tree.CompositeElement +import com.intellij.psi.search.PsiElementProcessor +import com.intellij.psi.util.PsiTreeUtil +import com.intellij.util.containers.ContainerUtil +import org.rust.lang.core.psi.RsMacroCall +import org.rust.lang.core.psi.RsWithMacroExpansionsRecursiveElementWalkingVisitor +import org.rust.lang.core.psi.ext.expansion /** @@ -22,3 +29,56 @@ inline fun PsiElement.forEachChild(action: (PsiElement) -> Unit) { psiChild = psiChild.nextSibling } } + +/** Behaves like [PsiTreeUtil.findChildOfAnyType], but also collects elements expanded from macros */ +fun findDescendantsWithMacrosOfAnyType( + element: PsiElement?, + strict: Boolean, + vararg classes: Class +): Collection { + if (element == null) return ContainerUtil.emptyList() + + val processor = object : PsiElementProcessor.CollectElements() { + override fun execute(each: PsiElement): Boolean { + if (strict && each === element) return true + return if (PsiTreeUtil.instanceOf(each, *classes)) { + super.execute(each) + } else true + } + } + processElementsWithMacros(element, processor) + @Suppress("UNCHECKED_CAST") + return processor.collection as Collection +} + +/** Behaves like [PsiTreeUtil.processElements], but also collects elements expanded from macros */ +fun processElementsWithMacros(element: PsiElement, processor: PsiElementProcessor): Boolean { + if (element is PsiCompiledElement || !element.isPhysical) { + // DummyHolders cannot be visited by walking visitors because children/parent relationship is broken there + if (!processor.execute(element)) return false + for (child in element.children) { + if (child is RsMacroCall && child.macroArgument != null) { + child.expansion?.elements?.forEach { + if (!processElementsWithMacros(it, processor)) return false + } + } else if (!processElementsWithMacros(child, processor)) { + return false + } + } + return true + } + + var result = true + element.accept(object : RsWithMacroExpansionsRecursiveElementWalkingVisitor() { + override fun visitElement(element: PsiElement) { + if (processor.execute(element)) { + super.visitElement(element) + } else { + stopWalking() + result = false + } + } + }) + + return result +} diff --git a/src/test/kotlin/org/rust/RsTestBase.kt b/src/test/kotlin/org/rust/RsTestBase.kt index 2b2902338b3..7a82f99af2c 100644 --- a/src/test/kotlin/org/rust/RsTestBase.kt +++ b/src/test/kotlin/org/rust/RsTestBase.kt @@ -16,6 +16,7 @@ import com.intellij.openapi.vfs.VirtualFile import com.intellij.openapi.vfs.VirtualFileFilter import com.intellij.psi.PsiElement import com.intellij.psi.impl.PsiManagerEx +import com.intellij.psi.util.PsiTreeUtil import com.intellij.testFramework.LightProjectDescriptor import com.intellij.testFramework.PlatformTestUtil import com.intellij.testFramework.UsefulTestCase @@ -28,8 +29,9 @@ import org.rust.cargo.project.model.cargoProjects import org.rust.cargo.project.workspace.CargoWorkspace import org.rust.cargo.toolchain.RustChannel import org.rust.cargo.toolchain.RustcVersion +import org.rust.lang.core.macros.findExpansionElements import org.rust.lang.core.macros.macroExpansionManager -import org.rust.lang.core.psi.ext.ancestorOrSelf +import org.rust.lang.core.psi.ext.startOffset import org.rust.openapiext.saveAllDocuments import org.rust.stdext.BothEditions @@ -206,8 +208,11 @@ abstract class RsTestBase : LightPlatformCodeInsightFixtureTestCase(), RsTestCas } } - protected inline fun findElementInEditor(marker: String = "^"): T { - val (element, data) = findElementWithDataAndOffsetInEditor(marker) + protected inline fun findElementInEditor(marker: String = "^"): T = + findElementInEditor(T::class.java, marker) + + protected fun findElementInEditor(psiClass: Class, marker: String): T { + val (element, data) = findElementWithDataAndOffsetInEditor(psiClass, marker) check(data.isEmpty()) { "Did not expect marker data" } return element } @@ -220,7 +225,14 @@ abstract class RsTestBase : LightPlatformCodeInsightFixtureTestCase(), RsTestCas protected inline fun findElementWithDataAndOffsetInEditor( marker: String = "^" ): Triple { - val elementsWithDataAndOffset = findElementsWithDataAndOffsetInEditor(marker) + return findElementWithDataAndOffsetInEditor(T::class.java, marker) + } + + protected fun findElementWithDataAndOffsetInEditor( + psiClass: Class, + marker: String + ): Triple { + val elementsWithDataAndOffset = findElementsWithDataAndOffsetInEditor(psiClass, marker) check(elementsWithDataAndOffset.isNotEmpty()) { "No `$marker` marker:\n${myFixture.file.text}" } check(elementsWithDataAndOffset.size <= 1) { "More than one `$marker` marker:\n${myFixture.file.text}" } return elementsWithDataAndOffset.first() @@ -228,6 +240,13 @@ abstract class RsTestBase : LightPlatformCodeInsightFixtureTestCase(), RsTestCas protected inline fun findElementsWithDataAndOffsetInEditor( marker: String = "^" + ): List> { + return findElementsWithDataAndOffsetInEditor(T::class.java, marker) + } + + protected fun findElementsWithDataAndOffsetInEditor( + psiClass: Class, + marker: String ): List> { val commentPrefix = LanguageCommenters.INSTANCE.forLanguage(myFixture.file.language).lineCommentPrefix ?: "//" val caretMarker = "$commentPrefix$marker" @@ -242,14 +261,25 @@ abstract class RsTestBase : LightPlatformCodeInsightFixtureTestCase(), RsTestCas val previousLine = LogicalPosition(markerPosition.line - 1, markerPosition.column) val elementOffset = myFixture.editor.logicalPositionToOffset(previousLine) val elementAtMarker = myFixture.file.findElementAt(elementOffset)!! - val element = elementAtMarker.ancestorOrSelf() + + if (followMacroExpansions) { + val expandedElementAtMarker = elementAtMarker.findExpansionElements()?.singleOrNull() + val expandedElement = expandedElementAtMarker?.let { PsiTreeUtil.getParentOfType(it, psiClass, false) } + if (expandedElement != null) { + val offset = expandedElementAtMarker.startOffset + (elementOffset - elementAtMarker.startOffset) + result.add(Triple(expandedElement, data, offset)) + continue + } + } + + val element = PsiTreeUtil.getParentOfType(elementAtMarker, psiClass, false) if (element != null) { result.add(Triple(element, data, elementOffset)) } else { val injectionElement = InjectedLanguageManager.getInstance(project) .findInjectedElementAt(myFixture.file, elementOffset) - ?.ancestorOrSelf() - ?: error("No ${T::class.java.simpleName} at ${elementAtMarker.text}") + ?.let { PsiTreeUtil.getParentOfType(it, psiClass, false) } + ?: error("No ${psiClass.simpleName} at ${elementAtMarker.text}") val injectionOffset = (injectionElement.containingFile.virtualFile as VirtualFileWindow) .documentWindow.hostToInjected(elementOffset) result.add(Triple(injectionElement, data, injectionOffset)) @@ -258,6 +288,8 @@ abstract class RsTestBase : LightPlatformCodeInsightFixtureTestCase(), RsTestCas return result } + protected open val followMacroExpansions: Boolean get() = false + protected fun replaceCaretMarker(text: String) = text.replace("/*caret*/", "") protected fun reportTeamCityMetric(name: String, value: Long) { diff --git a/src/test/kotlin/org/rust/lang/core/macros/RsMacroCallReferenceDelegationTest.kt b/src/test/kotlin/org/rust/lang/core/macros/RsMacroCallReferenceDelegationTest.kt index 7b02932c338..e9622500dbe 100644 --- a/src/test/kotlin/org/rust/lang/core/macros/RsMacroCallReferenceDelegationTest.kt +++ b/src/test/kotlin/org/rust/lang/core/macros/RsMacroCallReferenceDelegationTest.kt @@ -7,6 +7,7 @@ package org.rust.lang.core.macros import org.rust.ExpandMacros import org.rust.lang.core.resolve.RsResolveTestBase +import org.rust.lang.core.resolve.ref.RsMacroBodyReferenceDelegateImpl.Testmarks @ExpandMacros class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { @@ -17,7 +18,7 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { foo! { type T = X; } //^ - """) + """, Testmarks.touched) fun `test statement context`() = checkByCode(""" struct X; @@ -28,19 +29,16 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { type T = X; }; //^ } - """) + """, Testmarks.touched) - // TODO adjust type inference to take into account macros - fun `test expression context`() = expect { - checkByCode(""" + fun `test expression context`() = checkByCode(""" struct X; //X macro_rules! foo { ($($ i:tt)*) => { $( $ i )* }; } fn main () { let a = foo!(X); } //^ - """) - } + """, Testmarks.touched) fun `test type context`() = checkByCode(""" struct X; @@ -48,7 +46,7 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { macro_rules! foo { ($($ i:tt)*) => { $( $ i )* }; } type T = foo!(X); //^ - """) + """, Testmarks.touched) // TODO implement `getContext()` in all RsPat PSI elements fun `test pattern context`() = expect { @@ -63,7 +61,7 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { _ => {} } } - """) + """, Testmarks.touched) } fun `test lifetime`() = checkByCode(""" @@ -77,7 +75,7 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { fn foo(&self) -> &'a u8 {} } //^ } - """) + """, Testmarks.touched) fun `test 2-segment path 1`() = checkByCode(""" mod foo { @@ -88,7 +86,7 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { foo! { type T = foo::X; } //^ - """) + """, Testmarks.touched) fun `test 2-segment path 2`() = checkByCode(""" mod foo { @@ -98,5 +96,5 @@ class RsMacroCallReferenceDelegationTest : RsResolveTestBase() { foo! { type T = foo::X; } //^ - """) + """, Testmarks.touched) } diff --git a/src/test/kotlin/org/rust/lang/core/macros/RsMacroExpansionRangeMappingTest.kt b/src/test/kotlin/org/rust/lang/core/macros/RsMacroExpansionRangeMappingTest.kt index c3c26fda20c..2bce8ff8ecc 100644 --- a/src/test/kotlin/org/rust/lang/core/macros/RsMacroExpansionRangeMappingTest.kt +++ b/src/test/kotlin/org/rust/lang/core/macros/RsMacroExpansionRangeMappingTest.kt @@ -149,7 +149,7 @@ class RsMacroExpansionRangeMappingTest : RsTestBase() { val resolved = ref.reference.resolve() ?: error("Failed to resolve ${ref.text}") val elementInExpansion = refiner(resolved) check(elementInExpansion.isExpandedFromMacro) { "Must resolve to macro expansion" } - val elementInCallBody = elementInExpansion.findElementExpandedFrom() + val elementInCallBody = elementInExpansion.findElementExpandedFromChecked() ?: error("Failed to find element expanded from") assertEquals(myFixture.editor.caretModel.currentCaret.offset, elementInCallBody.startOffset) @@ -168,7 +168,7 @@ class RsMacroExpansionRangeMappingTest : RsTestBase() { val resolved = ref.reference.resolve() ?: error("Failed to resolve ${ref.text}") val elementInExpansion = refiner(resolved) check(elementInExpansion.isExpandedFromMacro) { "Must resolve to macro expansion" } - val elementInCallBody = elementInExpansion.findElementExpandedFrom() + val elementInCallBody = elementInExpansion.findElementExpandedFromChecked() assertNull(elementInCallBody) } diff --git a/src/test/kotlin/org/rust/lang/core/resolve/RsMacroExpansionResolveTest.kt b/src/test/kotlin/org/rust/lang/core/resolve/RsMacroExpansionResolveTest.kt index 27941f20359..19fb899e642 100644 --- a/src/test/kotlin/org/rust/lang/core/resolve/RsMacroExpansionResolveTest.kt +++ b/src/test/kotlin/org/rust/lang/core/resolve/RsMacroExpansionResolveTest.kt @@ -11,6 +11,8 @@ import org.rust.WithDependencyRustProjectDescriptor @ExpandMacros class RsMacroExpansionResolveTest : RsResolveTestBase() { + override val followMacroExpansions: Boolean get() = true + fun `test expand item`() = checkByCode(""" macro_rules! if_std { ($ i:item) => ( @@ -423,4 +425,88 @@ class RsMacroExpansionResolveTest : RsResolveTestBase() { foo().bar(); } //^ """) + + fun `test resolve binding from stmt context macro`() = checkByCode(""" + macro_rules! foo { + ($ i:stmt) => ( $ i ) + } + fn main() { + foo! { + let a = 0; + } //X + let _ = a; + } //^ + """) + + fun `test hygiene 1`() = checkByCode(""" + macro_rules! foo { + () => ( let a = 0; ) + } + fn main() { + foo!(); + let _ = a; + } //^ unresolved + """) + + fun `test hygiene 2`() = checkByCode(""" + macro_rules! foo { + ($ i:ident) => ( let $ i = 0; ) + } + macro_rules! bar { + ($($ t:tt)*) => { $($ t)* }; + } + fn main() { + bar! { + foo!(a); + //X + let _ = a; + } //^ + } + """) + + fun `test hygiene 3`() = checkByCode(""" + macro_rules! foo { + () => ( let a = 0; ) + } + macro_rules! bar { + ($($ t:tt)*) => { $($ t)* }; + } + fn main() { + bar! { + foo!(); + let _ = a; + } //^ unresolved + } + """) + + fun `test hygiene 4`() = checkByCode(""" + macro_rules! foo { + () => ( let a = 0; ) + } + macro_rules! bar { + ($($ t:tt)*) => { $($ t)* }; + } + fn main() { + let a = 1; + //X + bar! { + foo!(); + let _ = a; + } //^ + } + """) + + fun `test hygiene 5`() = checkByCode(""" + macro_rules! bar { + ($($ t:tt)*) => { $($ t)* }; + } + fn main() { + let a = 1; + //X + bar! { + let _ = a; + } //^ + let a = 2; + } + """) } diff --git a/src/test/kotlin/org/rust/lang/core/resolve/RsResolveTestBase.kt b/src/test/kotlin/org/rust/lang/core/resolve/RsResolveTestBase.kt index e45594c0849..a9582184498 100644 --- a/src/test/kotlin/org/rust/lang/core/resolve/RsResolveTestBase.kt +++ b/src/test/kotlin/org/rust/lang/core/resolve/RsResolveTestBase.kt @@ -29,6 +29,12 @@ abstract class RsResolveTestBase : RsTestBase() { protected inline fun checkByCodeGeneric( @Language("Rust") code: String, fileName: String = "main.rs" + ) = checkByCodeGeneric(T::class.java, code, fileName) + + protected fun checkByCodeGeneric( + targetPsiClass: Class, + @Language("Rust") code: String, + fileName: String = "main.rs" ) { InlineFile(code, fileName) @@ -43,7 +49,7 @@ abstract class RsResolveTestBase : RsTestBase() { } val resolved = refElement.checkedResolve(offset) - val target = findElementInEditor("X") + val target = findElementInEditor(targetPsiClass, "X") check(resolved == target) { "$refElement `${refElement.text}` should resolve to $target, was $resolved instead" diff --git a/src/test/kotlin/org/rust/lang/core/type/RsMacroTypeInferenceTest.kt b/src/test/kotlin/org/rust/lang/core/type/RsMacroTypeInferenceTest.kt new file mode 100644 index 00000000000..f5a2a2cbd9f --- /dev/null +++ b/src/test/kotlin/org/rust/lang/core/type/RsMacroTypeInferenceTest.kt @@ -0,0 +1,116 @@ +/* + * Use of this source code is governed by the MIT license that can be + * found in the LICENSE file. + */ + +package org.rust.lang.core.type + +class RsMacroTypeInferenceTest : RsTypificationTestBase() { + override val followMacroExpansions: Boolean get() = true + + fun `test let stmt expanded from a macro`() = testExpr(""" + macro_rules! foo { ($ s:stmt) => { $ s }; } + fn main() { + foo! { + let a = 0; + } + a; + } //^ i32 + """) + + fun `test let stmt (no semicolon) expanded from a macro`() = testExpr(""" + macro_rules! foo { ($ s:stmt) => { $ s }; } + fn main() { + foo! { + let a = 0 + } + a; + } //^ i32 + """) + + fun `test expr expanded from a macro`() = testExpr(""" + macro_rules! foo { ($ e:expr) => { $ e }; } + fn main() { + let a = foo!(0u16); + a; + } //^ u16 + """) + + fun `test stmt-context macro with an expression is typified 1`() = testExpr(""" + macro_rules! foo { ($ e:expr) => { $ e }; } + fn main() { + foo!(2 + 2); + //^ i32 + foobar; + } + """) + + fun `test stmt-context macro with an expression is typified 2`() = testExpr(""" + macro_rules! foo { ($ e:expr) => { 0; $ e }; } + fn main() { + foo!(2 + 2); + //^ i32 + foobar; + } + """) + + fun `test tail expr 1`() = testExpr(""" + macro_rules! foo { ($ s:stmt) => { 0u8; $ s }; } + fn main() { + let a = { foo! { 1u16 } }; + a; + } //^ u16 + """) + + fun `test tail expr 2`() = testExpr(""" + macro_rules! foo { ($ s:stmt) => { 0u8; $ s }; } + fn main() { + let a = { foo!(1u16); }; + a; + } //^ () + """) + + fun `test tail expr 3`() = testExpr(""" + macro_rules! foo { ($ s:stmt) => { 0u8; $ s }; } + fn main() { + let a = { foo!(foo!(1u16)); }; + a; + } //^ () + """) + + // TODO looks like there are needed changes in tail expr grammar + fun `test tail expr 4`() = expect { + testExpr (""" + macro_rules! foo { ($ s:stmt) => { 0u8; $ s }; } + fn main() { + let a = { foo! { foo!(1u16) } }; + a; + } //^ u16 + """) + } + + fun `test tail expr 5`() = testExpr(""" + macro_rules! foo { ($($ t:tt)*) => { 0u8; $($ t)* }; } + fn main() { + let a = { foo! { foo!(1u16); } }; + a; + } //^ () + """) + + fun `test tail expr 6`() = testExpr(""" + macro_rules! foo { ($ s:expr) => { $ s }; } + fn main() { + let a = { foo!(1u16) }; + a; + } //^ u16 + """) + + fun `test unification`() = testExpr(""" + macro_rules! foo { ($ s:stmt) => { $ s }; } + fn main() { + let a = 0; + foo! { a += 1u8 } + a; + } //^ u8 + """) +} diff --git a/src/test/kotlin/org/rust/lang/core/type/RsTypificationTestBase.kt b/src/test/kotlin/org/rust/lang/core/type/RsTypificationTestBase.kt index 827f2205328..432755b432f 100644 --- a/src/test/kotlin/org/rust/lang/core/type/RsTypificationTestBase.kt +++ b/src/test/kotlin/org/rust/lang/core/type/RsTypificationTestBase.kt @@ -9,9 +9,9 @@ import com.intellij.openapi.vfs.VirtualFileFilter import org.intellij.lang.annotations.Language import org.rust.RsTestBase import org.rust.fileTreeFromText +import org.rust.lang.core.macros.expandedFrom import org.rust.lang.core.psi.RsExpr -import org.rust.lang.core.psi.ext.RsInferenceContextOwner -import org.rust.lang.core.psi.ext.descendantsOfType +import org.rust.lang.core.psi.ext.* import org.rust.lang.core.types.inference import org.rust.lang.core.types.type import org.rust.lang.utils.Severity @@ -70,9 +70,9 @@ abstract class RsTypificationTestBase : RsTestBase() { } private fun checkAllExpressionsTypified() { - val notTypifiedExprs = myFixture.file.descendantsOfType().filter { expr -> + val notTypifiedExprs = myFixture.file.descendantsWithMacrosOfType().filter { expr -> expr.inference?.isExprTypeInferred(expr) == false - } + }.filter { it.expandedFrom?.resolveToMacro()?.isRustcDocOnlyMacro != true } if (notTypifiedExprs.isNotEmpty()) { error( notTypifiedExprs.joinToString(