Skip to content

Commit

Permalink
IDE: add imports after pasting code
Browse files Browse the repository at this point in the history
  • Loading branch information
Kobzol committed Aug 3, 2022
1 parent 39122de commit 583fc92
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 27 deletions.
Expand Up @@ -20,17 +20,32 @@ import com.intellij.psi.PsiFile
import com.intellij.psi.impl.source.tree.injected.changesHandler.range
import com.intellij.psi.util.parents
import org.rust.ide.inspections.import.AutoImportFix
import org.rust.ide.utils.import.ImportCandidate
import org.rust.ide.utils.import.ImportCandidatesCollector2
import org.rust.ide.utils.import.ImportContext2
import org.rust.ide.utils.import.import
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.resolve.VALUES
import org.rust.lang.core.types.inference
import org.rust.openapiext.toPsiFile
import java.awt.datatransfer.DataFlavor
import java.awt.datatransfer.Transferable
import java.awt.datatransfer.UnsupportedFlavorException
import java.io.IOException

class RsTextBlockTransferableData(val offsetToFqnMap: Map<Int, String>) : TextBlockTransferableData, Cloneable {
data class ImportMap(private val offsetToFqnMap: Map<Int, String>) {
fun elementToFqn(element: PsiElement, range: TextRange): String? {
val offset = toRelativeOffset(element, range)
return offsetToFqnMap[offset]
}

companion object {
val Empty: ImportMap = ImportMap(emptyMap())
}
}

class RsTextBlockTransferableData(val importMap: ImportMap) : TextBlockTransferableData, Cloneable {
override fun getFlavor(): DataFlavor? = RsImportCopyPasteProcessor.dataFlavor

override fun getOffsetCount(): Int = 0
Expand Down Expand Up @@ -61,7 +76,7 @@ class RsImportCopyPasteProcessor : CopyPastePostProcessor<RsTextBlockTransferabl
val map = if (ranges.size == 1) {
createFqnMap(file, ranges[0])
} else {
emptyMap()
ImportMap.Empty
}

return listOf(
Expand Down Expand Up @@ -97,12 +112,17 @@ class RsImportCopyPasteProcessor : CopyPastePostProcessor<RsTextBlockTransferabl
val elements = gatherElements(rsFile, range)
val importCtx = elements.firstOrNull { it is RsElement } as? RsElement ?: return

val visitor = ImportingVisitor(project, importCtx, range, data)
val visitor = ImportingVisitor(importCtx, range, data)

runWriteAction {
for (element in elements) {
element.accept(visitor)
}
// We need to import the candidates after visiting all elements, otherwise the relative offsets could be
// invalidated after an import has been added
for ((candidate, ctx) in visitor.importCandidates) {
candidate.import(ctx)
}
}
}

Expand All @@ -127,34 +147,50 @@ class RsImportCopyPasteProcessor : CopyPastePostProcessor<RsTextBlockTransferabl
private class RsReferenceData

private class ImportingVisitor(
private val project: Project,
private val importCtx: RsElement,
private val range: TextRange,
private val data: RsTextBlockTransferableData
) : RsRecursiveVisitor() {
private val candidates: MutableList<Pair<ImportCandidate, RsElement>> = mutableListOf()

val importCandidates: List<Pair<ImportCandidate, RsElement>> = candidates

override fun visitPath(path: RsPath) {
val ctx = AutoImportFix.findApplicableContext(project, path)
val ctx = AutoImportFix.findApplicableContext(path)
handleContext(path, ctx)
super.visitPath(path)
}

override fun visitMethodCall(methodCall: RsMethodCall) {
val ctx = AutoImportFix.findApplicableContext(project, methodCall)
val ctx = AutoImportFix.findApplicableContext(methodCall)
handleContext(methodCall, ctx)
super.visitMethodCall(methodCall)
}

override fun visitPatBinding(binding: RsPatBinding) {
if (data.importMap.elementToFqn(binding, range) != null) {
val referenceName = binding.referenceName
val isNameInScope = binding.hasInScope(referenceName, VALUES)
if (!isNameInScope) {
val importContext = ImportContext2.from(binding, ImportContext2.Type.AUTO_IMPORT)
if (importContext != null) {
val candidates = ImportCandidatesCollector2.getImportCandidates(importContext, referenceName)
val ctx = AutoImportFix.Context(AutoImportFix.Type.GENERAL_PATH, candidates)
handleContext(binding, ctx)
}
}
}
super.visitPatBinding(binding)
}

private fun handleContext(element: PsiElement, ctx: AutoImportFix.Context?) {
if (ctx != null) {
if (ctx.candidates.size == 1) {
ctx.candidates[0].import(importCtx)
} else {
val candidate = ctx.candidates.find {
val offset = toRelativeOffset(element, range)
val fqn = data.offsetToFqnMap[offset]
fqn == it.qualifiedNamedItem.item.qualifiedName
}
candidate?.import(importCtx)
val candidate = ctx.candidates.find {
val fqn = data.importMap.elementToFqn(element, range)
fqn == it.qualifiedNamedItem.item.qualifiedName
}
if (candidate != null) {
candidates.add(candidate to importCtx)
}
}
}
Expand All @@ -164,38 +200,46 @@ private class ImportingVisitor(
* Records mapping between offsets (relative to copy/paste content range) and fully qualified names of resolved items
* from paths and method calls.
*/
private fun createFqnMap(file: RsFile, range: TextRange): Map<Int, String> {
private fun createFqnMap(file: RsFile, range: TextRange): ImportMap {
val elements = gatherElements(file, range)
val map = mutableMapOf<Int, String>()
val fqnMap = mutableMapOf<Int, String>()

val visitor = object : RsRecursiveVisitor() {
override fun visitPath(path: RsPath) {
val target = (path.reference?.resolve() as? RsQualifiedNamedElement)?.qualifiedName
if (target != null) {
map[toRelativeOffset(path, range)] = target
fqnMap[toRelativeOffset(path, range)] = target
}

super.visitPath(path)
}

override fun visitMethodCall(methodCall: RsMethodCall) {
val methods = methodCall.inference?.getResolvedMethod(methodCall)
val target = methods?.mapNotNull {
val target = methods?.firstNotNullOfOrNull {
it.source.implementedTrait?.element?.qualifiedName
}?.firstOrNull()
}

if (target != null) {
map[toRelativeOffset(methodCall, range)] = target
fqnMap[toRelativeOffset(methodCall, range)] = target
}

super.visitMethodCall(methodCall)
}

override fun visitPatBinding(binding: RsPatBinding) {
val target = (binding.reference.resolve() as? RsQualifiedNamedElement)?.qualifiedName
if (target != null) {
fqnMap[toRelativeOffset(binding, range)] = target
}
super.visitPatBinding(binding)
}
}
for (element in elements) {
element.accept(visitor)
}

return map
return ImportMap(fqnMap)
}

private fun gatherElements(file: RsFile, range: TextRange): List<PsiElement> {
Expand Down
Expand Up @@ -6,14 +6,9 @@
package org.rust.ide.typing.paste

import com.intellij.openapi.actionSystem.IdeActions
import com.intellij.util.ui.UIUtil
import org.intellij.lang.annotations.Language
import org.rust.MockEdition
import org.rust.cargo.project.workspace.CargoWorkspace
import org.rust.fileTreeFromText
import org.rust.openapiext.saveAllDocuments

@MockEdition(CargoWorkspace.Edition.EDITION_2018)
class RsAddImportOnCopyPasteTest : RsCopyPasteTestBase() {
fun `test type reference`() = doCopyPasteTest("""
//- lib.rs
Expand Down Expand Up @@ -433,6 +428,34 @@ class RsAddImportOnCopyPasteTest : RsCopyPasteTestBase() {
}
""")

fun `test do not import private item`() = doCopyPasteTest("""
//- lib.rs
mod a {
pub struct S;
pub fn foo1() -> S { S }
}
mod b {
struct S;
<selection>fn foo2() -> S { S }</selection>
}
/*caret*/
""", """
//- lib.rs
mod a {
pub struct S;
pub fn foo1() -> S { S }
}
mod b {
struct S;
fn foo2() -> S { S }
}
fn foo2() -> S { S }
""")

fun `test copy paste same location`() = doCopyPasteTest("""
//- lib.rs
mod foo {
Expand Down Expand Up @@ -461,6 +484,96 @@ class RsAddImportOnCopyPasteTest : RsCopyPasteTestBase() {
}
""")

fun `test pat qualified path`() = doCopyPasteTest("""
//- lib.rs
mod foo {
pub mod consts {
pub const CONST: u32 = 1;
}
fn bar(a: u32) -> u32 {
<selection>match a {
consts::CONST => 1,
_ => 2
}</selection>
}
}
mod baz {
fn bar(a: u32) -> u32 {
/*caret*/
}
}
""", """
//- lib.rs
mod foo {
pub mod consts {
pub const CONST: u32 = 1;
}
fn bar(a: u32) -> u32 {
match a {
consts::CONST => 1,
_ => 2
}
}
}
mod baz {
use crate::foo::consts;
fn bar(a: u32) -> u32 {
match a {
consts::CONST => 1,
_ => 2
}
}
}
""")

fun `test pat binding`() = doCopyPasteTest("""
//- lib.rs
mod foo {
pub const CONST: u32 = 1;
fn bar(a: u32) -> u32 {
<selection>match a {
CONST => 1,
_ => 2
}</selection>
}
}
mod baz {
fn bar(a: u32) -> u32 {
/*caret*/
}
}
""", """
//- lib.rs
mod foo {
pub const CONST: u32 = 1;
fn bar(a: u32) -> u32 {
match a {
CONST => 1,
_ => 2
}
}
}
mod baz {
use crate::foo::CONST;
fn bar(a: u32) -> u32 {
match a {
CONST => 1,
_ => 2
}
}
}
""")

private fun doCopyPasteTest(
@Language("Rust") before: String,
@Language("Rust") after: String
Expand Down

0 comments on commit 583fc92

Please sign in to comment.