Skip to content

Commit

Permalink
#6411: Provide Rust-specific element features
Browse files Browse the repository at this point in the history
Now we provide these two features:
* Categorical feature `kind` to consider keyword/PSI type of element
* Binary feature `is_from_stdlib` to consider if element's origin is stdlib
  • Loading branch information
artemmukhin committed Nov 30, 2020
1 parent 91bc8ea commit 4571960
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 0 deletions.
130 changes: 130 additions & 0 deletions ml-completion/src/main/kotlin/org/rust/ml/RsElementFeatureProvider.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
import org.rust.ide.utils.import.isStd
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.RsElement
import org.rust.lang.core.psi.ext.containingCrate
import kotlin.reflect.KClass

@Suppress("UnstableApiUsage")
class RsElementFeatureProvider : ElementFeatureProvider {
override fun getName(): String = "rust"

override fun calculateFeatures(
element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures
): Map<String, MLFeatureValue> {
val result = hashMapOf<String, MLFeatureValue>()
val lookupString = element.lookupString

/** If [element] is a keyword, store keyword kind as [KIND] feature and finish */
val keywordKind = RsKeywordMLKind.from(lookupString)
if (keywordKind != null) {
result[KIND] = MLFeatureValue.categorical(keywordKind)
return result
}

/**
* Otherwise, if [element] is [RsElement],
* store PSI kind as [KIND] feature
* and store if its origin is stdlib as [IS_FROM_STDLIB] feature
*/
val psiElement = element.psiElement as? RsElement ?: return result
val psiElementKind = RsPsiElementMLKind.from(psiElement)
if (psiElementKind != null) {
result[KIND] = MLFeatureValue.categorical(psiElementKind)
}
val containingCrate = psiElement.containingMod.containingCrate
if (containingCrate != null) {
result[IS_FROM_STDLIB] = MLFeatureValue.binary(containingCrate.isStd)
}

return result
}

companion object {
private const val KIND: String = "kind"
private const val IS_FROM_STDLIB: String = "is_from_stdlib"
}
}

/** Should be synchronized with [org.rust.lang.core.psi.RsTokenTypeKt#RS_KEYWORDS] */
internal enum class RsKeywordMLKind(val lookupString: String) {
As("as"),
Box("box"), Break("break"),
Const("const"), Continue("continue"), Crate("crate"), CSelf("Self"),
Default("default"),
Else("else"), Enum("enum"), Extern("extern"),
Fn("fn"), For("for"),
If("if"), Impl("impl"), In("in"),
Macro("macro"),
Let("let"), Loop("loop"),
Match("match"), Mod("mod"), Move("move"), Mut("mut"),
Pub("pub"),
Ref("ref"), Return("return"),
Self("self"), Static("static"), Struct("struct"), Super("super"),
Trait("trait"), Type("type"),
Union("union"), Unsafe("unsafe"), Use("use"),
Where("where"), While("while"),
Yield("yield");

companion object {
fun from(lookupString: String): RsKeywordMLKind? {
return values().find { it.lookupString == lookupString }
}
}
}

@Suppress("unused")
internal enum class RsPsiElementMLKind(val klass: KClass<out RsElement>) {
PatBinding(RsPatBinding::class),
Function(RsFunction::class),
StructItem(RsStructItem::class),
TraitItem(RsTraitItem::class),
NamedFieldDecl(RsNamedFieldDecl::class),
File(RsFile::class),
EnumVariant(RsEnumVariant::class),
SelfParameter(RsSelfParameter::class),
Macro(RsMacro::class),
EnumItem(RsEnumItem::class),
TypeAlias(RsTypeAlias::class),
LifetimeParameter(RsLifetimeParameter::class),
Constant(RsConstant::class),
ModItem(RsModItem::class),
TupleFieldDecl(RsTupleFieldDecl::class),
PathExpr(RsPathExpr::class),
DotExpr(RsDotExpr::class),
BaseType(RsBaseType::class),
PatIdent(RsPatIdent::class),
UseSpeck(RsUseSpeck::class),
ImplItem(RsImplItem::class),
StructLiteralBody(RsStructLiteralBody::class),
MacroArgument(RsMacroArgument::class),
MetaItem(RsMetaItem::class),
BlockFields(RsBlockFields::class),
TraitRef(RsTraitRef::class),
ValueArgumentList(RsValueArgumentList::class),
Path(RsPath::class),
FormatMacroArg(RsFormatMacroArg::class),
MacroCall(RsMacroCall::class),
StructLiteral(RsStructLiteral::class),
RefLikeType(RsRefLikeType::class),
TypeArgumentList(RsTypeArgumentList::class);

companion object {
fun from(element: RsElement): RsPsiElementMLKind? {
return values().find { it.klass.isInstance(element) }
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
<idea-plugin>
<extensions defaultExtensionNs="com.intellij">
<completion.ml.ranking.features.policy language="Rust" implementationClass="org.rust.ml.RsCompletionFeaturesPolicy"/>
<completion.ml.elementFeatures language="Rust" implementationClass="org.rust.ml.RsElementFeatureProvider"/>
</extensions>
</idea-plugin>
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import com.intellij.codeInsight.lookup.LookupManager
import com.intellij.codeInsight.lookup.impl.LookupImpl
import com.intellij.completion.ml.util.RelevanceUtil.asRelevanceMaps
import com.intellij.testFramework.UsefulTestCase
import org.intellij.lang.annotations.Language
import org.rust.ProjectDescriptor
import org.rust.RsTestBase
import org.rust.WithStdlibRustProjectDescriptor
import org.rust.lang.core.psi.RS_KEYWORDS

/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

@Suppress("UnstableApiUsage")
class RsElementFeatureProviderTest : RsTestBase() {
fun `test top level keyword "kind" features`() = doTest("ml_rust_kind", """
/*caret*/
fn main() {}
""", mapOf(
"struct" to RsKeywordMLKind.Struct.name,
"enum" to RsKeywordMLKind.Enum.name,
"fn" to RsKeywordMLKind.Fn.name,
"const" to RsKeywordMLKind.Const.name,
"pub" to RsKeywordMLKind.Pub.name,
"extern" to RsKeywordMLKind.Extern.name,
"trait" to RsKeywordMLKind.Trait.name,
"type" to RsKeywordMLKind.Type.name,
"use" to RsKeywordMLKind.Use.name,
"static" to RsKeywordMLKind.Static.name,
))

fun `test body keyword "kind" features`() = doTest("ml_rust_kind", """
fn main() {
/*caret*/
}
""", mapOf(
"let" to RsKeywordMLKind.Let.name,
"struct" to RsKeywordMLKind.Struct.name,
"enum" to RsKeywordMLKind.Enum.name,
"if" to RsKeywordMLKind.If.name,
"match" to RsKeywordMLKind.Match.name,
"return" to RsKeywordMLKind.Return.name,
"crate" to RsKeywordMLKind.Crate.name,
))

fun `test named elements "kind" features`() = doTest("ml_rust_kind", """
fn foo_func() {}
fn main() {
let foo_var = 42;
f/*caret*/
}
""", mapOf(
"foo_var" to RsPsiElementMLKind.PatBinding.name,
"foo_func" to RsPsiElementMLKind.Function.name,
))

fun `test struct field method "kind" feature`() = doTest("ml_rust_kind", """
struct S { field1: i32, field2: i32 }
impl S {
fn foo(&self) {}
}
fn foo(s: S) {
s.f/*caret*/
}
""", mapOf(
"field1" to RsPsiElementMLKind.NamedFieldDecl.name,
"foo" to RsPsiElementMLKind.Function.name,
))

@ProjectDescriptor(WithStdlibRustProjectDescriptor::class)
fun `test "is_from_stdlib" feature`() = doTest("ml_rust_is_from_stdlib", """
fn my_print() {}
fn main() {
prin/*caret*/
}
""", mapOf(
"println" to "1",
"my_print" to "0",
))

fun `test all keywords are covered`() {
val kindsKeywords = RsKeywordMLKind.values().map {
it.lookupString
}
val actualKeywords = RS_KEYWORDS.types.map {
when (val name = it.toString()) {
"default_kw" -> "default"
"union_kw" -> "union"
else -> name
}
}
UsefulTestCase.assertSameElements(kindsKeywords, actualKeywords)
}


private fun doTest(feature: String, @Language("Rust") code: String, values: Map<String, Any>) {
InlineFile(code.trimIndent()).withCaret()
myFixture.completeBasic()
val lookup = LookupManager.getInstance(project).activeLookup as LookupImpl
for ((lookupString, expectedValue) in values) {
checkFeatureValue(lookup, feature, lookupString, expectedValue)
}
}

private fun checkFeatureValue(
lookup: LookupImpl,
feature: String,
lookupString: String,
expectedValue: Any
) {
val items = lookup.items
val allRelevanceObjects = lookup.getRelevanceObjects(items, false)

val matchedItem = items.firstOrNull { it.lookupString == lookupString } ?: error("No `$lookupString` in lookup")
val relevanceObjects = allRelevanceObjects[matchedItem].orEmpty()
val featuresMap = asRelevanceMaps(relevanceObjects).second
val actualValue = featuresMap[feature]
assertEquals("Invalid value for `$feature` of `$lookupString`", expectedValue, actualValue)
}
}
5 changes: 5 additions & 0 deletions ml-completion/src/test/resources/META-INF/plugin.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<idea-plugin xmlns:xi="http://www.w3.org/2001/XInclude" allow-bundled-update="true">
<id>org.rust.ml</id>
<xi:include href="/META-INF/rust-core.xml" xpointer="xpointer(/idea-plugin/*)"/>
<depends optional="true" config-file="ml-completion-only.xml">com.intellij.stats.completion</depends>
</idea-plugin>

0 comments on commit 4571960

Please sign in to comment.