Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide ML completion ranking for Rust #6419

Merged
merged 4 commits into from Dec 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Expand Up @@ -198,6 +198,7 @@ The current Rust plugin modules:
* `:coverage` - integration with [coverage](https://github.com/JetBrains/intellij-community/tree/master/plugins/coverage-common) plugin
* `:grazie` - integration with [grazie](https://plugins.jetbrains.com/plugin/12175-grazie) plugin
* `:js` - interop with JavaScript language
* `:ml-completion` - integration with [Machine Learning Code Completion](https://github.com/JetBrains/intellij-community/tree/master/plugins/completion-ml-ranking) plugin

If you want to implement integration with another plugin/IDE, you should create a new gradle module for that.

Expand Down
30 changes: 29 additions & 1 deletion build.gradle.kts
@@ -1,3 +1,4 @@
import groovy.json.JsonSlurper
import org.apache.tools.ant.taskdefs.condition.Os.*
import org.gradle.api.JavaVersion.VERSION_1_8
import org.gradle.api.internal.HasConvention
Expand All @@ -12,7 +13,6 @@ import org.jetbrains.intellij.tasks.RunIdeTask
import org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet
import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
import org.jsoup.Jsoup
import groovy.json.JsonSlurper
import java.io.Writer
import java.net.URL
import kotlin.concurrent.thread
Expand Down Expand Up @@ -45,6 +45,7 @@ val javaPlugin = "java"
val javaScriptPlugin = "JavaScript"
// BACKCOMPAT: 2020.2
val clionPlugins = if (platformVersion < 203) emptyList() else listOf("com.intellij.cidr.base", "com.intellij.clion")
val mlCompletionPlugin = "com.intellij.completion.ml.ranking"
artemmukhin marked this conversation as resolved.
Show resolved Hide resolved

plugins {
idea
Expand Down Expand Up @@ -74,6 +75,7 @@ allprojects {
jcenter()
maven("https://dl.bintray.com/jetbrains/markdown")
maven("http://download.eclipse.org/jgit/maven")
maven("https://dl.bintray.com/jetbrains/intellij-third-party-dependencies")
}

idea {
Expand Down Expand Up @@ -194,6 +196,10 @@ project(":plugin") {
psiViewerPlugin,
javaScriptPlugin
)
// BACKCOMPAT: 2020.2
if (platformVersion >= 203) {
plugins += mlCompletionPlugin
}
if (baseIDE == "idea") {
plugins += listOf(
copyrightPlugin,
Expand All @@ -216,6 +222,7 @@ project(":plugin") {
implementation(project(":duplicates"))
implementation(project(":grazie"))
implementation(project(":js"))
implementation(project(":ml-completion"))
}

tasks {
Expand Down Expand Up @@ -450,6 +457,27 @@ project(":js") {
}
}

project(":ml-completion") {
intellij {
// BACKCOMPAT: 2020.2
if (platformVersion >= 203) {
val plugins = mutableListOf<Any>(mlCompletionPlugin)
// TODO: drop it when CLion move `navigation.class.hierarchy` property from c-plugin to CLion resources
if (baseIDE == "clion") {
plugins += "c-plugin"
}
setPlugins(*plugins.toTypedArray())
}
}
dependencies {
implementation("org.jetbrains.intellij.deps.completion:completion-ranking-rust:0.0.4")
implementation(project(":"))
implementation(project(":common"))
testImplementation(project(":", "testOutput"))
testImplementation(project(":common", "testOutput"))
}
}

project(":intellij-toml") {
version = "0.2.$patchVersion.${prop("buildNumber")}$versionSuffix"
intellij {
Expand Down
@@ -0,0 +1,12 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import com.intellij.completion.ml.features.CompletionFeaturesPolicy

class RsCompletionFeaturesPolicy : CompletionFeaturesPolicy {
override fun useNgramModel(): Boolean = true
}
@@ -0,0 +1,130 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
import org.rust.ide.utils.import.isStd
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.RsElement
import org.rust.lang.core.psi.ext.containingCrate
import kotlin.reflect.KClass

@Suppress("UnstableApiUsage")
class RsElementFeatureProvider : ElementFeatureProvider {
override fun getName(): String = "rust"

override fun calculateFeatures(
element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures
): Map<String, MLFeatureValue> {
val result = hashMapOf<String, MLFeatureValue>()
val lookupString = element.lookupString

/** If [element] is a keyword, store keyword kind as [KIND] feature and finish */
val keywordKind = RsKeywordMLKind.from(lookupString)
if (keywordKind != null) {
result[KIND] = MLFeatureValue.categorical(keywordKind)
return result
}

/**
* Otherwise, if [element] is [RsElement],
* store PSI kind as [KIND] feature
* and store if its origin is stdlib as [IS_FROM_STDLIB] feature
*/
val psiElement = element.psiElement as? RsElement ?: return result
val psiElementKind = RsPsiElementMLKind.from(psiElement)
if (psiElementKind != null) {
result[KIND] = MLFeatureValue.categorical(psiElementKind)
}
val containingCrate = psiElement.containingMod.containingCrate
if (containingCrate != null) {
result[IS_FROM_STDLIB] = MLFeatureValue.binary(containingCrate.isStd)
}

return result
}

companion object {
private const val KIND: String = "kind"
private const val IS_FROM_STDLIB: String = "is_from_stdlib"
}
}

/** Should be synchronized with [org.rust.lang.core.psi.RsTokenTypeKt#RS_KEYWORDS] */
internal enum class RsKeywordMLKind(val lookupString: String) {
As("as"), Async("async"), Auto("auto"),
Box("box"), Break("break"),
Const("const"), Continue("continue"), Crate("crate"), CSelf("Self"),
Default("default"), Dyn("dyn"),
Else("else"), Enum("enum"), Extern("extern"),
Fn("fn"), For("for"),
If("if"), Impl("impl"), In("in"),
Macro("macro"),
Let("let"), Loop("loop"),
Match("match"), Mod("mod"), Move("move"), Mut("mut"),
Pub("pub"),
Raw("raw"), Ref("ref"), Return("return"),
Self("self"), Static("static"), Struct("struct"), Super("super"),
Trait("trait"), Type("type"),
Union("union"), Unsafe("unsafe"), Use("use"),
Where("where"), While("while"),
Yield("yield");

companion object {
fun from(lookupString: String): RsKeywordMLKind? {
return values().find { it.lookupString == lookupString }
}
}
}

@Suppress("unused")
internal enum class RsPsiElementMLKind(val klass: KClass<out RsElement>) {
PatBinding(RsPatBinding::class),
Function(RsFunction::class),
StructItem(RsStructItem::class),
TraitItem(RsTraitItem::class),
NamedFieldDecl(RsNamedFieldDecl::class),
File(RsFile::class),
EnumVariant(RsEnumVariant::class),
SelfParameter(RsSelfParameter::class),
Macro(RsMacro::class),
EnumItem(RsEnumItem::class),
TypeAlias(RsTypeAlias::class),
LifetimeParameter(RsLifetimeParameter::class),
Constant(RsConstant::class),
ModItem(RsModItem::class),
TupleFieldDecl(RsTupleFieldDecl::class),
PathExpr(RsPathExpr::class),
DotExpr(RsDotExpr::class),
BaseType(RsBaseType::class),
PatIdent(RsPatIdent::class),
UseSpeck(RsUseSpeck::class),
ImplItem(RsImplItem::class),
StructLiteralBody(RsStructLiteralBody::class),
MacroArgument(RsMacroArgument::class),
MetaItem(RsMetaItem::class),
BlockFields(RsBlockFields::class),
TraitRef(RsTraitRef::class),
ValueArgumentList(RsValueArgumentList::class),
Path(RsPath::class),
FormatMacroArg(RsFormatMacroArg::class),
MacroCall(RsMacroCall::class),
StructLiteral(RsStructLiteral::class),
RefLikeType(RsRefLikeType::class),
TypeArgumentList(RsTypeArgumentList::class);

companion object {
fun from(element: RsElement): RsPsiElementMLKind? {
return values().find { it.klass.isInstance(element) }
}
}
}
@@ -0,0 +1,15 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import com.intellij.internal.ml.catboost.CatBoostJarCompletionModelProvider
import com.intellij.lang.Language
import org.rust.lang.RsLanguage

@Suppress("UnstableApiUsage")
class RsMLRankingProvider : CatBoostJarCompletionModelProvider("Rust", "rust_features", "rust_model") {
override fun isLanguageSupported(language: Language): Boolean = language == RsLanguage
}
@@ -0,0 +1,8 @@
<idea-plugin>
<depends>com.intellij.completion.ml.ranking</depends>
<extensions defaultExtensionNs="com.intellij">
<completion.ml.ranking.features.policy language="Rust" implementationClass="org.rust.ml.RsCompletionFeaturesPolicy"/>
<completion.ml.elementFeatures language="Rust" implementationClass="org.rust.ml.RsElementFeatureProvider"/>
<completion.ml.model implementation="org.rust.ml.RsMLRankingProvider"/>
</extensions>
</idea-plugin>
@@ -0,0 +1,128 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import com.intellij.codeInsight.lookup.LookupManager
import com.intellij.codeInsight.lookup.impl.LookupImpl
import com.intellij.completion.ml.util.RelevanceUtil.asRelevanceMaps
import com.intellij.testFramework.UsefulTestCase
import org.intellij.lang.annotations.Language
import org.rust.ProjectDescriptor
import org.rust.RsTestBase
import org.rust.WithStdlibRustProjectDescriptor
import org.rust.lang.core.psi.RS_KEYWORDS

@Suppress("UnstableApiUsage")
class RsElementFeatureProviderTest : RsTestBase() {
fun `test top level keyword kind features`() = doTest("ml_rust_kind", """
/*caret*/
fn main() {}
""", mapOf(
"struct" to RsKeywordMLKind.Struct.name,
"enum" to RsKeywordMLKind.Enum.name,
"fn" to RsKeywordMLKind.Fn.name,
"const" to RsKeywordMLKind.Const.name,
"pub" to RsKeywordMLKind.Pub.name,
"extern" to RsKeywordMLKind.Extern.name,
"trait" to RsKeywordMLKind.Trait.name,
"type" to RsKeywordMLKind.Type.name,
"use" to RsKeywordMLKind.Use.name,
"static" to RsKeywordMLKind.Static.name,
))

fun `test body keyword kind features`() = doTest("ml_rust_kind", """
fn main() {
/*caret*/
}
""", mapOf(
"let" to RsKeywordMLKind.Let.name,
"struct" to RsKeywordMLKind.Struct.name,
"enum" to RsKeywordMLKind.Enum.name,
"if" to RsKeywordMLKind.If.name,
"match" to RsKeywordMLKind.Match.name,
"return" to RsKeywordMLKind.Return.name,
"crate" to RsKeywordMLKind.Crate.name,
))

fun `test named elements kind features`() = doTest("ml_rust_kind", """
fn foo_func() {}
fn main() {
let foo_var = 42;
f/*caret*/
}
""", mapOf(
"foo_var" to RsPsiElementMLKind.PatBinding.name,
"foo_func" to RsPsiElementMLKind.Function.name,
))

fun `test struct field method kind feature`() = doTest("ml_rust_kind", """
struct S { field1: i32, field2: i32 }
impl S {
fn foo(&self) {}
}
fn foo(s: S) {
s.f/*caret*/
}
""", mapOf(
"field1" to RsPsiElementMLKind.NamedFieldDecl.name,
"foo" to RsPsiElementMLKind.Function.name,
))

@ProjectDescriptor(WithStdlibRustProjectDescriptor::class)
fun `test is_from_stdlib feature`() = doTest("ml_rust_is_from_stdlib", """
fn my_print() {}
fn main() {
prin/*caret*/
}
""", mapOf(
"println" to "1",
"my_print" to "0",
))

fun `test all keywords are covered`() {
val kindsKeywords = RsKeywordMLKind.values().map {
it.lookupString
}
val actualKeywords = RS_KEYWORDS.types.map {
when (val name = it.toString()) {
"async_kw" -> "async"
"auto_kw" -> "auto"
"default_kw" -> "default"
"dyn_kw" -> "dyn"
"raw_kw" -> "raw"
"union_kw" -> "union"
else -> name
}
}
UsefulTestCase.assertSameElements(kindsKeywords, actualKeywords)
}


private fun doTest(feature: String, @Language("Rust") code: String, values: Map<String, Any>) {
InlineFile(code.trimIndent()).withCaret()
myFixture.completeBasic()
val lookup = LookupManager.getInstance(project).activeLookup as LookupImpl
for ((lookupString, expectedValue) in values) {
checkFeatureValue(lookup, feature, lookupString, expectedValue)
}
}

private fun checkFeatureValue(
lookup: LookupImpl,
feature: String,
lookupString: String,
expectedValue: Any
) {
val items = lookup.items
val allRelevanceObjects = lookup.getRelevanceObjects(items, false)

val matchedItem = items.firstOrNull { it.lookupString == lookupString } ?: error("No `$lookupString` in lookup")
val relevanceObjects = allRelevanceObjects[matchedItem].orEmpty()
val featuresMap = asRelevanceMaps(relevanceObjects).second
val actualValue = featuresMap[feature]
assertEquals("Invalid value for `$feature` of `$lookupString`", expectedValue, actualValue)
}
}
@@ -0,0 +1,13 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.ml

import org.junit.Test

class RsModelMetadataConsistencyTest {
@Test
fun `test model metadata consistency`() = RsMLRankingProvider().assertModelMetadataConsistent()
}