-
Notifications
You must be signed in to change notification settings - Fork 380
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#6411: Provide Rust-specific element features
Now we provide these two features: * Categorical feature `kind` to consider keyword/PSI type of element * Binary feature `is_from_stdlib` to consider if element's origin is stdlib
- Loading branch information
1 parent
91bc8ea
commit 1048244
Showing
4 changed files
with
267 additions
and
0 deletions.
There are no files selected for viewing
130 changes: 130 additions & 0 deletions
130
ml-completion/src/main/kotlin/org/rust/ml/RsElementFeatureProvider.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
/* | ||
* Use of this source code is governed by the MIT license that can be | ||
* found in the LICENSE file. | ||
*/ | ||
|
||
package org.rust.ml | ||
|
||
import com.intellij.codeInsight.completion.CompletionLocation | ||
import com.intellij.codeInsight.completion.ml.ContextFeatures | ||
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider | ||
import com.intellij.codeInsight.completion.ml.MLFeatureValue | ||
import com.intellij.codeInsight.lookup.LookupElement | ||
import org.rust.ide.utils.import.isStd | ||
import org.rust.lang.core.psi.* | ||
import org.rust.lang.core.psi.ext.RsElement | ||
import org.rust.lang.core.psi.ext.containingCrate | ||
import kotlin.reflect.KClass | ||
|
||
@Suppress("UnstableApiUsage") | ||
class RsElementFeatureProvider : ElementFeatureProvider { | ||
override fun getName(): String = "rust" | ||
|
||
override fun calculateFeatures( | ||
element: LookupElement, | ||
location: CompletionLocation, | ||
contextFeatures: ContextFeatures | ||
): Map<String, MLFeatureValue> { | ||
val result = hashMapOf<String, MLFeatureValue>() | ||
val lookupString = element.lookupString | ||
|
||
/** If [element] is a keyword, store keyword kind as "kind" feature and finish */ | ||
val keywordKind = RsKeywordMLKind.from(lookupString) | ||
if (keywordKind != null) { | ||
result["kind"] = MLFeatureValue.categorical(keywordKind) | ||
return result | ||
} | ||
|
||
/** | ||
* Otherwise, if [element] is [RsElement], | ||
* store PSI kind as "kind" feature | ||
* and store if its origin is stdlib as "is_from_stdlib" feature | ||
*/ | ||
val psiElement = element.psiElement as? RsElement ?: return result | ||
val rsPsiElementKind = RsPsiElementMLKind.from(psiElement) | ||
if (rsPsiElementKind != null) { | ||
result[Features.Kind] = MLFeatureValue.categorical(rsPsiElementKind) | ||
} | ||
val containingCrate = psiElement.containingMod.containingCrate | ||
if (containingCrate != null) { | ||
result[Features.IsFromStdlib] = MLFeatureValue.binary(containingCrate.isStd) | ||
} | ||
|
||
return result | ||
} | ||
|
||
private object Features { | ||
const val Kind: String = "kind" | ||
const val IsFromStdlib: String = "is_from_stdlib" | ||
} | ||
} | ||
|
||
/** Should be synchronized with [org.rust.lang.core.psi.RsTokenTypeKt#RS_KEYWORDS] */ | ||
internal enum class RsKeywordMLKind(val lookupString: String) { | ||
As("as"), | ||
Box("box"), Break("break"), | ||
Const("const"), Continue("continue"), Crate("crate"), CSelf("Self"), | ||
Default("default"), | ||
Else("else"), Enum("enum"), Extern("extern"), | ||
Fn("fn"), For("for"), | ||
If("if"), Impl("impl"), In("in"), | ||
Macro("macro"), | ||
Let("let"), Loop("loop"), | ||
Match("match"), Mod("mod"), Move("move"), Mut("mut"), | ||
Pub("pub"), | ||
Ref("ref"), Return("return"), | ||
Self("self"), Static("static"), Struct("struct"), Super("super"), | ||
Trait("trait"), Type("type"), | ||
Union("union"), Unsafe("unsafe"), Use("use"), | ||
Where("where"), While("while"), | ||
Yield("yield"); | ||
|
||
companion object { | ||
fun from(lookupString: String): RsKeywordMLKind? { | ||
return values().find { it.lookupString == lookupString } | ||
} | ||
} | ||
} | ||
|
||
@Suppress("unused") | ||
internal enum class RsPsiElementMLKind(val klass: KClass<out RsElement>) { | ||
PatBinding(RsPatBinding::class), | ||
Function(RsFunction::class), | ||
StructItem(RsStructItem::class), | ||
TraitItem(RsTraitItem::class), | ||
NamedFieldDecl(RsNamedFieldDecl::class), | ||
File(RsFile::class), | ||
EnumVariant(RsEnumVariant::class), | ||
SelfParameter(RsSelfParameter::class), | ||
Macro(RsMacro::class), | ||
EnumItem(RsEnumItem::class), | ||
TypeAlias(RsTypeAlias::class), | ||
LifetimeParameter(RsLifetimeParameter::class), | ||
Constant(RsConstant::class), | ||
ModItem(RsModItem::class), | ||
TupleFieldDecl(RsTupleFieldDecl::class), | ||
PathExpr(RsPathExpr::class), | ||
DotExpr(RsDotExpr::class), | ||
BaseType(RsBaseType::class), | ||
PatIdent(RsPatIdent::class), | ||
UseSpeck(RsUseSpeck::class), | ||
ImplItem(RsImplItem::class), | ||
StructLiteralBody(RsStructLiteralBody::class), | ||
MacroArgument(RsMacroArgument::class), | ||
MetaItem(RsMetaItem::class), | ||
BlockFields(RsBlockFields::class), | ||
TraitRef(RsTraitRef::class), | ||
ValueArgumentList(RsValueArgumentList::class), | ||
Path(RsPath::class), | ||
FormatMacroArg(RsFormatMacroArg::class), | ||
MacroCall(RsMacroCall::class), | ||
StructLiteral(RsStructLiteral::class), | ||
RefLikeType(RsRefLikeType::class), | ||
TypeArgumentList(RsTypeArgumentList::class); | ||
|
||
companion object { | ||
fun from(element: RsElement): RsPsiElementMLKind? { | ||
return values().find { it.klass.isInstance(element) } | ||
} | ||
} | ||
} |
1 change: 1 addition & 0 deletions
1
ml-completion/src/main/resources/META-INF/ml-completion-only.xml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
<idea-plugin> | ||
<extensions defaultExtensionNs="com.intellij"> | ||
<completion.ml.ranking.features.policy language="Rust" implementationClass="org.rust.ml.RsCompletionFeaturesPolicy"/> | ||
<completion.ml.elementFeatures language="Rust" implementationClass="org.rust.ml.RsElementFeatureProvider"/> | ||
</extensions> | ||
</idea-plugin> |
131 changes: 131 additions & 0 deletions
131
ml-completion/src/test/kotlin/org/rust/ml/RsElementFeatureProviderTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
/* | ||
* Use of this source code is governed by the MIT license that can be | ||
* found in the LICENSE file. | ||
*/ | ||
|
||
package org.rust.ml | ||
|
||
import com.intellij.codeInsight.lookup.LookupManager | ||
import com.intellij.codeInsight.lookup.impl.LookupImpl | ||
import com.intellij.completion.ml.util.RelevanceUtil.asRelevanceMaps | ||
import com.intellij.testFramework.UsefulTestCase | ||
import org.intellij.lang.annotations.Language | ||
import org.junit.Test | ||
import org.rust.ProjectDescriptor | ||
import org.rust.RsTestBase | ||
import org.rust.WithStdlibRustProjectDescriptor | ||
import org.rust.lang.core.psi.RS_KEYWORDS | ||
|
||
/* | ||
* Use of this source code is governed by the MIT license that can be | ||
* found in the LICENSE file. | ||
*/ | ||
|
||
@Suppress("UnstableApiUsage") | ||
class RsElementFeatureProviderTest : RsTestBase() { | ||
fun `test top level keyword "kind" features`() = doTest("ml_rust_kind", """ | ||
/*caret*/ | ||
fn main() {} | ||
""", mapOf( | ||
"struct" to RsKeywordMLKind.Struct.name, | ||
"enum" to RsKeywordMLKind.Enum.name, | ||
"fn" to RsKeywordMLKind.Fn.name, | ||
"const" to RsKeywordMLKind.Const.name, | ||
"pub" to RsKeywordMLKind.Pub.name, | ||
"extern" to RsKeywordMLKind.Extern.name, | ||
"trait" to RsKeywordMLKind.Trait.name, | ||
"type" to RsKeywordMLKind.Type.name, | ||
"use" to RsKeywordMLKind.Use.name, | ||
"static" to RsKeywordMLKind.Static.name, | ||
)) | ||
|
||
fun `test body keyword "kind" features`() = doTest("ml_rust_kind", """ | ||
fn main() { | ||
/*caret*/ | ||
} | ||
""", mapOf( | ||
"let" to RsKeywordMLKind.Let.name, | ||
"struct" to RsKeywordMLKind.Struct.name, | ||
"enum" to RsKeywordMLKind.Enum.name, | ||
"if" to RsKeywordMLKind.If.name, | ||
"match" to RsKeywordMLKind.Match.name, | ||
"return" to RsKeywordMLKind.Return.name, | ||
"crate" to RsKeywordMLKind.Crate.name, | ||
)) | ||
|
||
fun `test named elements "kind" features`() = doTest("ml_rust_kind", """ | ||
fn foo_func() {} | ||
fn main() { | ||
let foo_var = 42; | ||
f/*caret*/ | ||
} | ||
""", mapOf( | ||
"foo_var" to RsPsiElementMLKind.PatBinding.name, | ||
"foo_func" to RsPsiElementMLKind.Function.name, | ||
)) | ||
|
||
fun `test struct field method "kind" feature`() = doTest("ml_rust_kind", """ | ||
struct S { field1: i32, field2: i32 } | ||
impl S { | ||
fn foo(&self) {} | ||
} | ||
fn foo(s: S) { | ||
s.f/*caret*/ | ||
} | ||
""", mapOf( | ||
"field1" to RsPsiElementMLKind.NamedFieldDecl.name, | ||
"foo" to RsPsiElementMLKind.Function.name, | ||
)) | ||
|
||
@ProjectDescriptor(WithStdlibRustProjectDescriptor::class) | ||
fun `test "is_from_stdlib" feature`() = doTest("ml_rust_is_from_stdlib", """ | ||
fn my_print() {} | ||
fn main() { | ||
prin/*caret*/ | ||
} | ||
""", mapOf( | ||
"println" to "1", | ||
"my_print" to "0", | ||
)) | ||
|
||
@Test | ||
fun `test all keywords are covered`() { | ||
val kindsKeywords = RsKeywordMLKind.values().map { | ||
it.lookupString | ||
} | ||
val actualKeywords = RS_KEYWORDS.types.map { | ||
when (val name = it.toString()) { | ||
"default_kw" -> "default" | ||
"union_kw" -> "union" | ||
else -> name | ||
} | ||
} | ||
UsefulTestCase.assertSameElements(kindsKeywords, actualKeywords) | ||
} | ||
|
||
|
||
private fun doTest(feature: String, @Language("Rust") code: String, values: Map<String, Any>) { | ||
InlineFile(code.trimIndent()).withCaret() | ||
myFixture.completeBasic() | ||
val lookup = LookupManager.getInstance(project).activeLookup as LookupImpl | ||
for ((lookupString, expectedValue) in values) { | ||
checkFeatureValue(lookup, feature, lookupString, expectedValue) | ||
} | ||
} | ||
|
||
private fun checkFeatureValue( | ||
lookup: LookupImpl, | ||
feature: String, | ||
lookupString: String, | ||
expectedValue: Any | ||
) { | ||
val items = lookup.items | ||
val allRelevanceObjects = lookup.getRelevanceObjects(items, false) | ||
|
||
val matchedItem = items.firstOrNull { it.lookupString == lookupString } ?: error("No `$lookupString` in lookup") | ||
val relevanceObjects = allRelevanceObjects[matchedItem].orEmpty() | ||
val featuresMap = asRelevanceMaps(relevanceObjects).second | ||
val actualValue = featuresMap[feature] | ||
assertEquals("Invalid value for `$feature` of `$lookupString`", expectedValue, actualValue) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<idea-plugin xmlns:xi="http://www.w3.org/2001/XInclude" allow-bundled-update="true"> | ||
<id>org.rust.ml</id> | ||
<xi:include href="/META-INF/rust-core.xml" xpointer="xpointer(/idea-plugin/*)"/> | ||
<depends optional="true" config-file="ml-completion-only.xml">com.intellij.stats.completion</depends> | ||
</idea-plugin> |