Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve completion of if and match keywords #10744

Merged
merged 3 commits into from Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -53,26 +53,29 @@ class MatchPostfixTemplate(provider: RsPostfixTemplateProvider) :
PsiDocumentManager.getInstance(project).doPostponedOperationsAndUnblockDocument(editor.document)

val matchBody = matchExpr.matchBody ?: return
val arm = matchBody.matchArmList.firstOrNull() ?: return
val blockExpr = arm.expr as? RsBlockExpr ?: return
val toBeReplaced = processor.getElementsToReplace(matchBody)

if (toBeReplaced.isEmpty()) {
moveCaretToMatchArmBlock(editor, blockExpr)
} else {
val blockExprPointer = blockExpr.createSmartPointer()
editor.buildAndRunTemplate(matchBody, toBeReplaced) {
val restoredBlockExpr = blockExprPointer.element ?: return@buildAndRunTemplate
moveCaretToMatchArmBlock(editor, restoredBlockExpr)
}
}
editor.fillArmsPlaceholders(toBeReplaced, matchExpr)
}
}

private fun moveCaretToMatchArmBlock(editor: Editor, blockExpr: RsBlockExpr) {
editor.caretModel.moveToOffset(blockExpr.block.lbrace.textOffset + 1)
private fun Editor.fillArmsPlaceholders(elementsToReplace: Collection<RsElement>, match: RsMatchExpr) {
val firstArmBlock = match.matchBody?.matchArmList?.firstOrNull()?.expr as? RsBlockExpr ?: return
if (elementsToReplace.isEmpty()) {
moveCaretToMatchArmBlock(this, firstArmBlock)
} else {
val firstArmBlockPointer = firstArmBlock.createSmartPointer()
buildAndRunTemplate(match, elementsToReplace) {
val restored = firstArmBlockPointer.element ?: return@buildAndRunTemplate
moveCaretToMatchArmBlock(this, restored)
}
}
}

private fun moveCaretToMatchArmBlock(editor: Editor, blockExpr: RsBlockExpr) {
editor.caretModel.moveToOffset(blockExpr.block.lbrace.textOffset + 1)
}

private fun getMatchProcessor(ty: Ty, context: RsElement): MatchProcessor {
return when {
ty is TyAdt && ty.item == context.knownItems.String -> StringMatchProcessor
Expand Down Expand Up @@ -101,7 +104,7 @@ private object GenericMatchProcessor : MatchProcessor() {
} else {
AddRemainingArmsFix(matchExpr, patterns)
}
fix.invoke(matchExpr.project, matchExpr.containingFile, matchExpr, matchExpr)
fix.invoke(matchExpr.project, editor = null, matchExpr)
}

override fun getElementsToReplace(matchBody: RsMatchBody): List<RsElement> =
Expand All @@ -123,3 +126,9 @@ private open class StringLikeMatchProcessor : MatchProcessor() {
private object StringMatchProcessor : StringLikeMatchProcessor() {
override fun expressionToText(expression: RsExpr): String = "${expression.text}.as_str()"
}

fun fillMatchArms(match: RsMatchExpr, editor: Editor) {
GenericMatchProcessor.normalizeMatch(match)
val elementsToReplace = match.descendantsOfType<RsPatWild>()
editor.fillArmsPlaceholders(elementsToReplace, match)
}
Expand Up @@ -8,6 +8,9 @@ package org.rust.lang.core.completion
import com.intellij.codeInsight.completion.*
import com.intellij.codeInsight.completion.ml.MLRankingIgnorable
import com.intellij.codeInsight.lookup.LookupElementBuilder
import com.intellij.codeInsight.template.impl.MacroCallNode
import com.intellij.codeInsight.template.macro.CompleteMacro
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.components.service
import com.intellij.openapi.editor.EditorModificationUtil
import com.intellij.openapi.project.DumbAware
Expand All @@ -19,6 +22,8 @@ import com.intellij.patterns.StandardPatterns.or
import com.intellij.psi.*
import com.intellij.psi.tree.TokenSet
import com.intellij.util.ProcessingContext
import org.rust.ide.template.postfix.fillMatchArms
import org.rust.ide.utils.template.newTemplateBuilder
import org.rust.lang.core.*
import org.rust.lang.core.RsPsiPattern.baseDeclarationPattern
import org.rust.lang.core.RsPsiPattern.baseInherentImplDeclarationPattern
Expand All @@ -28,6 +33,8 @@ import org.rust.lang.core.completion.RsLookupElementProperties.KeywordKind
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.RsElementTypes.*
import org.rust.lang.core.psi.ext.*
import org.rust.openapiext.createSmartPointer
import org.rust.openapiext.moveCaretToOffset

/**
* Completes Rust keywords
Expand Down Expand Up @@ -100,21 +107,45 @@ class RsKeywordCompletionContributor : CompletionContributor(), DumbAware {
super.fillCompletionVariants(parameters, RsCompletionContributor.withRustSorter(parameters, result))
}

private fun conditionLookupElement(lookupString: String): LookupElementBuilder {
private fun conditionLookupElement(keyword: String): LookupElementBuilder {
return LookupElementBuilder
.create(lookupString)
.create(keyword)
.bold()
.withTailText(" {...}")
.withInsertHandler { context, _ ->
val isLetExpr = context.file.findElementAt(context.tailOffset - 1)
?.ancestorStrict<RsLetDecl>()
?.let { it.expr?.text == lookupString } == true
val hasSemicolon = context.nextCharIs(';')

var tail = " { }"
if (isLetExpr && !hasSemicolon) tail += ';'
context.document.insertString(context.selectionEndOffset, tail)
EditorModificationUtil.moveCaretRelatively(context.editor, 1)
conditionLookupElementHandleInsert(context, keyword)
}
}

private fun conditionLookupElementHandleInsert(context: InsertionContext, keyword: String) {
val element0: RsExpr = when (keyword) {
"if", "else if" -> context.getElementOfType<RsIfExpr>()
"match" -> context.getElementOfType<RsMatchExpr>()
else -> null
} ?: return
val elementPointer = element0.createSmartPointer()

val semicolon = if (element0.parent is RsLetDecl && !context.nextCharIs(';')) ";" else ""
// `f` is condition expr which will be replaced by template builder
context.document.insertString(context.selectionEndOffset, " f { }$semicolon")
PsiDocumentManager.getInstance(context.project).commitDocument(context.document)

val element1 = elementPointer.element ?: return
val expr = when (element1) {
is RsIfExpr -> element1.condition?.expr
is RsMatchExpr -> element1.expr
else -> null
} ?: return
context.editor.newTemplateBuilder(element1)
.replaceElement(expr, MacroCallNode(CompleteMacro()))
.runInline {
val element2 = elementPointer.element ?: return@runInline
context.editor.moveCaretToOffset(element2, element2.endOffset - " }".length)
if (element2 is RsMatchExpr && !DumbService.isDumb(element2.project)) {
runWriteAction {
fillMatchArms(element2, context.editor)
}
}
}
}

Expand Down
51 changes: 51 additions & 0 deletions src/test/kotlin/org/rust/lang/core/completion/RsCompletionTest.kt
Expand Up @@ -1618,4 +1618,55 @@ class RsCompletionTest : RsCompletionTestBase() {
fun `test no fn main inside other function 3`() = checkNoCompletion("""
fn func() /*caret*/{}
""")

fun `test match completion 1`() = doSingleCompletionWithLiveTemplate("""
enum E { A, B }
fn main() {
let x = E::A;
mat/*caret*/
}
""", "x\t", """
enum E { A, B }
fn main() {
let x = E::A;
match x {
E::A => {/*caret*/}
E::B => {}
}
}
""")

fun `test match completion 2`() = doSingleCompletionWithLiveTemplate("""
enum E { A, B(i32) }
fn main() {
let x = E::A;
mat/*caret*/
}
""", "x\ty\t", """
enum E { A, B(i32) }
fn main() {
let x = E::A;
match x {
E::A => {/*caret*/}
E::B(y) => {}
}
}
""")

fun `test match completion 3`() = doSingleCompletionWithLiveTemplate("""
enum E { A, B(i32) }
fn main() {
let x = E::A;
let a = mat/*caret*/
}
""", "x\ty\t", """
enum E { A, B(i32) }
fn main() {
let x = E::A;
let a = match x {
E::A => {/*caret*/}
E::B(y) => {}
};
}
""")
}
Expand Up @@ -6,6 +6,7 @@
package org.rust.lang.core.completion

import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.codeInsight.lookup.LookupElementPresentation
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import org.intellij.lang.annotations.Language
Expand Down Expand Up @@ -68,6 +69,14 @@ abstract class RsCompletionTestBase(private val defaultFileName: String = "main.
@Language("Rust") after: String
) = completionFixture.doSingleCompletionByFileTree(before, after)

protected fun doSingleCompletionWithLiveTemplate(
@Language("Rust") before: String,
toType: String,
@Language("Rust") after: String
) = checkByTextWithLiveTemplate(before, after, toType) {
executeSoloCompletion()
}

protected fun checkContainsCompletion(
variant: String,
@Language("Rust") code: String,
Expand Down Expand Up @@ -104,6 +113,22 @@ abstract class RsCompletionTestBase(private val defaultFileName: String = "main.
completionChar: Char = '\n'
) = completionFixture.checkCompletion(lookupString, before, after, completionChar)

fun checkCompletionWithLiveTemplate(
lookupString: String,
@Language("Rust") before: String,
toType: String,
@Language("Rust") after: String
) {
checkByTextWithLiveTemplate(before.trimIndent(), after.trimIndent(), toType) {
val items = myFixture.completeBasic()!!
val lookupItem = items.find {
it.presentation.itemText == lookupString
} ?: error("Lookup string $lookupString not found")
myFixture.lookup.currentItem = lookupItem
myFixture.type('\n')
}
}

protected fun checkNotContainsCompletion(
variant: String,
@Language("Rust") code: String,
Expand Down Expand Up @@ -157,3 +182,6 @@ abstract class RsCompletionTestBase(private val defaultFileName: String = "main.
RIGHT
}
}

val LookupElement.presentation: LookupElementPresentation
get() = LookupElementPresentation().also { renderElement(it) }
Expand Up @@ -526,13 +526,13 @@ class RsKeywordCompletionContributorTest : RsCompletionTestBase() {
}
""")

fun `test else if`() = checkCompletion("else if", """
fun `test else if`() = checkCompletionWithLiveTemplate("else if", """
fn main() {
if true { } /*caret*/
}
""", """
""", "foo\t", """
fn main() {
if true { } else if /*caret*/ { }
if true { } else if foo { /*caret*/ }
}
""")

Expand Down Expand Up @@ -695,43 +695,43 @@ class RsKeywordCompletionContributorTest : RsCompletionTestBase() {
impl<T> Foo<T> for Bar where /*caret*/
""")

fun `test if or match in start of statement`() = checkCompletion(CONDITION_KEYWORDS, """
fun `test if or match in start of statement`() = checkCompletionWithLiveTemplate(CONDITION_KEYWORDS, """
fn foo() {
/*caret*/
}
""", """
""", "foo\t", """
fn foo() {
/*lookup*/ /*caret*/ { }
/*lookup*/ foo { /*caret*/ }
}
""")

fun `test if or match in let statement`() = checkCompletion(CONDITION_KEYWORDS, """
fun `test if or match in let statement`() = checkCompletionWithLiveTemplate(CONDITION_KEYWORDS, """
fn foo() {
let x = /*caret*/
}
""", """
""", "foo\t", """
fn foo() {
let x = /*lookup*/ /*caret*/ { };
let x = /*lookup*/ foo { /*caret*/ };
}
""")

fun `test if or match in let statement with semicolon`() = checkCompletion(CONDITION_KEYWORDS, """
fun `test if or match in let statement with semicolon`() = checkCompletionWithLiveTemplate(CONDITION_KEYWORDS, """
fn foo() {
let x = /*caret*/;
}
""", """
""", "foo\t", """
fn foo() {
let x = /*lookup*/ /*caret*/ { };
let x = /*lookup*/ foo { /*caret*/ };
}
""")

fun `test if or match in expression`() = checkCompletion(CONDITION_KEYWORDS, """
fun `test if or match in expression`() = checkCompletionWithLiveTemplate(CONDITION_KEYWORDS, """
fn foo() {
let x = 1 + /*caret*/
}
""", """
""", "foo\t", """
fn foo() {
let x = 1 + /*lookup*/ /*caret*/ { }
let x = 1 + /*lookup*/ foo { /*caret*/ }
}
""")

Expand Down Expand Up @@ -1229,6 +1229,17 @@ class RsKeywordCompletionContributorTest : RsCompletionTestBase() {
}
}

private fun checkCompletionWithLiveTemplate(
lookupStrings: List<String>,
@Language("Rust") before: String,
toType: String,
@Language("Rust") after: String
) {
for (lookupString in lookupStrings) {
checkCompletionWithLiveTemplate(lookupString, before, toType, after.replace("/*lookup*/", lookupString))
}
}

companion object {
private val MEMBERS_KEYWORDS = listOf("fn", "type", "const", "unsafe")
}
Expand Down
Expand Up @@ -5,7 +5,6 @@

package org.rust.lang.core.completion

import com.intellij.codeInsight.lookup.LookupElementPresentation
import org.intellij.lang.annotations.Language
import org.rust.ProjectDescriptor
import org.rust.WithStdlibRustProjectDescriptor
Expand Down Expand Up @@ -189,9 +188,7 @@ class RsLambdaExprCompletionTest : RsCompletionTestBase() {
private fun checkCompletion() {
val items = myFixture.completeBasic()
val item = items.find {
val presentation = LookupElementPresentation()
it.renderElement(presentation)
presentation.itemText == "|| {}"
it.presentation.itemText == "|| {}"
} ?: error("No lambda completion found")
myFixture.lookup.currentItem = item
myFixture.type('\n')
Expand Down
Expand Up @@ -9,7 +9,6 @@ import com.intellij.codeInsight.completion.CompletionResultSet
import com.intellij.codeInsight.completion.CompletionSorter
import com.intellij.codeInsight.completion.PrefixMatcher
import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.codeInsight.lookup.LookupElementPresentation
import com.intellij.patterns.ElementPattern
import com.intellij.psi.NavigatablePsiElement
import com.intellij.util.containers.MultiMap
Expand Down Expand Up @@ -172,9 +171,7 @@ class RsLookupElementTest : RsTestBase() {
SimpleScopeEntry("foo", myFixture.file as RsFile, TYPES),
RsCompletionContext()
)
val presentation = LookupElementPresentation()

lookup.renderElement(presentation)
val presentation = lookup.presentation
assertNotNull(presentation.icon)
assertEquals("foo", presentation.itemText)
}
Expand Down Expand Up @@ -406,8 +403,7 @@ class RsLookupElementTest : RsTestBase() {
isBold: Boolean,
isStrikeout: Boolean
) {
val presentation = LookupElementPresentation()
lookup.renderElement(presentation)
val presentation = lookup.presentation

assertNotNull("Item icon should be not null", presentation.icon)
assertEquals("Tail text mismatch", tailText, presentation.tailText)
Expand Down