Skip to content

Commit

Permalink
Merge #10757 #10758
Browse files Browse the repository at this point in the history
10757: Fixes add #[tokio::main] quick fix. r=ceronman a=ceronman

This changes the way we are installing tokio dependency. Instead of using `cargo add` we're editing Cargo.toml directly. This allows us to have more control over the features. New helper function to add cargo dependency is added as well.

changelog: Fixes undo not available after adding tokio dependency when using Add `#[tokio::main]` quickfix.


10758: T: Added some extra tests for value out of range inspection. r=ceronman a=ceronman



Co-authored-by: vlad20012 <beskvlad@gmail.com>
Co-authored-by: Manuel Ceron <manuel.ceron@jetbrains.com>
  • Loading branch information
3 people committed Jul 28, 2023
3 parents 8ac87df + 7ccca17 + 8ba3eb3 commit d024d4d
Show file tree
Hide file tree
Showing 10 changed files with 475 additions and 84 deletions.
43 changes: 17 additions & 26 deletions src/main/kotlin/org/rust/ide/fixes/AddTokioMainFix.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,41 @@
package org.rust.ide.fixes

import com.intellij.openapi.editor.Editor
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.Task
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.project.Project
import org.rust.RsBundle
import org.rust.cargo.project.model.cargoProjects
import org.rust.cargo.project.settings.toolchain
import org.rust.cargo.toolchain.tools.cargo
import org.rust.lang.core.psi.RsFunction
import org.rust.lang.core.psi.RsPsiFactory
import org.rust.lang.core.psi.ext.*
import org.rust.openapiext.document
import org.rust.toml.addCargoDependency
import org.rust.toml.getPackageCargoTomlFile

class AddTokioMainFix(function: RsFunction) : RsQuickFixBase<RsFunction>(function) {
private val hasTokio = function.findDependencyCrateRoot("tokio") != null
override fun getFamilyName() = RsBundle.message("intention.name.add.tokio.main")
override fun getText(): String {
return if (hasTokio) {
RsBundle.message("intention.name.add.tokio.main")
} else {
RsBundle.message("intention.name.install.tokio.and.add.main")
}
override fun getFamilyName(): String = RsBundle.message("intention.name.add.tokio.main")
override fun getText(): String = RsBundle.message("intention.name.add.tokio.main")

}
override fun invoke(project: Project, editor: Editor?, element: RsFunction) {
if (!element.isAsync) {
val anchor = element.unsafe ?: element.externAbi ?: element.fn
element.addBefore(RsPsiFactory(project).createAsyncKeyword(), anchor)
}

val anchor = element.outerAttrList.firstOrNull() ?: element.firstKeyword
element.addOuterAttribute(Attribute("tokio::main"), anchor)

if (!element.isIntentionPreviewElement && !hasTokio) {
installTokio(project)
if (!element.isIntentionPreviewElement) {
element.containingCrate.addCargoDependency("tokio", "1.0.0", REQUIRED_TOKIO_FEATURES)
element.containingCrate.cargoTarget?.pkg?.getPackageCargoTomlFile(project)?.document?.let {
FileDocumentManager.getInstance().saveDocument(it)
}

project.cargoProjects.refreshAllProjects()
}
}
private fun installTokio(project: Project) {
val cargo = project.toolchain?.cargo() ?: return
object : Task.Backgroundable(project, RsBundle.message("progress.title.adding.dependency", "tokio")) {
override fun shouldStartInBackground(): Boolean = true
override fun run(indicator: ProgressIndicator) {
cargo.addDependency(project, "tokio", listOf("full"))
}
override fun onSuccess() {
project.cargoProjects.refreshAllProjects()
}
}.queue()

companion object {
private val REQUIRED_TOKIO_FEATURES = listOf("rt", "rt-multi-thread", "macros")
}
}
83 changes: 83 additions & 0 deletions src/main/kotlin/org/rust/toml/CrateExt.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

package org.rust.toml

import org.rust.lang.core.crate.Crate
import org.rust.openapiext.checkWriteAccessAllowed
import org.toml.lang.psi.*

/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/

/**
* Adds dependency [name] with version [version] to the corresponding `[dependencies]` section
* in the corresponding `Cargo.toml` file if the dependency doesn't exist already. If it does,
* update the features of the dependency with [features] if required.
*
* For example, if [name] = "tokio", [version] = "1.0.0" and [features] = "full", it inserts
* ```
* [dependencies]
* tokio = { version = "1.0.0", features = ["full"] }
* ```
*/
fun Crate.addCargoDependency(name: String, version: String, features: List<String> = emptyList()) {
checkWriteAccessAllowed()

val cargoToml = cargoTarget?.pkg?.getPackageCargoTomlFile(project) ?: return
val factory = TomlPsiFactory(project)

val featuresArray = features.joinToString(prefix = "[", separator = ", ", postfix = "]") { "\"$it\"" }

when (val existingDependency = cargoToml.findDependencyElement(name)) {
is TomlKeyValueOwner -> {
updateDependencyFeatures(factory, existingDependency, features)
}
is TomlLiteral -> {
val newVersion = existingDependency.stringValue ?: version
val newEntry = factory.createInlineTable("""version = "$newVersion", features = $featuresArray""")
existingDependency.replace(newEntry)
}
else -> {
val existingDependencies = cargoToml.tableList.find {
it.header.key?.stringValue == "dependencies"
}
val dependencies = existingDependencies ?: run {
val newDependenciesTable = factory.createTable("dependencies")
cargoToml.add(factory.createWhitespace("\n"))
cargoToml.add(newDependenciesTable) as TomlTable
}
val newDependencyKeyValue = if (features.isEmpty()) {
factory.createKeyValue(name, version)
} else {
factory.createKeyValue(name, """{ version = "$version", features = $featuresArray }""")
}

dependencies.add(factory.createWhitespace("\n"))
dependencies.add(newDependencyKeyValue)
}
}
}

private fun updateDependencyFeatures(factory: TomlPsiFactory, table: TomlKeyValueOwner, features: List<String>) {
val featuresEntry = table.entries.find { entry -> entry.key.stringValue == "features" }
if (featuresEntry == null) {
val featuresArray = features.joinToString(prefix = "[", separator = ", ", postfix = "]") { "\"$it\"" }
val newEntry = factory.createKeyValue("features", featuresArray)
val newTable = (table.entries + listOf(newEntry)).joinToString(separator=", ") {
"""${it.key.text} = ${it.value?.text}"""
}
table.replace(factory.createInlineTable(newTable))
} else {
val existingFeatures = (featuresEntry.value as? TomlArray)?.elements
?.mapNotNull { value -> value.stringValue }
?: emptyList()
val newFeatures = (existingFeatures + features).distinct()
val newFeaturesArray = newFeatures.joinToString(prefix = "[", separator = ", ", postfix = "]") { "\"$it\"" }
featuresEntry.replace(factory.createKeyValue("features", newFeaturesArray))
}
}
79 changes: 72 additions & 7 deletions src/main/kotlin/org/rust/toml/Util.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ import org.rust.cargo.CargoConstants
import org.rust.cargo.project.workspace.CargoWorkspace
import org.rust.ide.notifications.showBalloonWithoutProject
import org.rust.lang.core.completion.getElementOfType
import org.rust.lang.core.psi.ext.ancestorOrSelf
import org.rust.lang.core.psi.ext.elementType
import org.rust.lang.core.psi.ext.findCargoPackage
import org.rust.lang.core.psi.ext.isAncestorOf
import org.rust.lang.core.psi.ext.*
import org.rust.openapiext.toPsiFile
import org.toml.lang.psi.*
import org.toml.lang.psi.ext.TomlLiteralKind
Expand Down Expand Up @@ -136,15 +133,15 @@ fun getClosestKeyValueAncestor(position: PsiElement): TomlKeyValue? {
}
}

fun CargoWorkspace.Package.getPackageTomlFile(project: Project): TomlFile? {
fun CargoWorkspace.Package.getPackageCargoTomlFile(project: Project): TomlFile? {
return contentRoot?.findChild(CargoConstants.MANIFEST_FILE)
?.toPsiFile(project)
as? TomlFile
}

fun PsiElement.findCargoPackageForCargoToml(): CargoWorkspace.Package? {
val containingFile = containingFile.originalFile
return containingFile.findCargoPackage()?.takeIf { it.getPackageTomlFile(containingFile.project) == containingFile }
return containingFile.findCargoPackage()?.takeIf { it.getPackageCargoTomlFile(containingFile.project) == containingFile }
}

private fun CargoWorkspace.Package.findDependencyByPackageName(pkgName: String): CargoWorkspace.Package? =
Expand All @@ -153,7 +150,7 @@ private fun CargoWorkspace.Package.findDependencyByPackageName(pkgName: String):
fun findDependencyTomlFile(element: TomlElement, depName: String): TomlFile? =
element.findCargoPackageForCargoToml()
?.findDependencyByPackageName(depName)
?.getPackageTomlFile(element.project)
?.getPackageCargoTomlFile(element.project)

/**
* Consider `Cargo.toml`:
Expand Down Expand Up @@ -185,8 +182,76 @@ val TomlValue.containingDependencyKey: TomlKeySegment?
}
}

val TomlKey.stringValue: String
get() {
return segments.map { it.name }.joinToString(".")
}

val TomlValue.stringValue: String?
get() {
val kind = (this as? TomlLiteral)?.kind
return (kind as? TomlLiteralKind.String)?.value
}

val TomlFile.tableList: List<TomlTable> get() = childrenOfType<TomlTable>()


fun TomlFile.findDependencyElement(dependencyName: String): TomlElement? {
// Check for cases like:
// [dependencies.foo]
// version = "1.0.0"
val existingInlinedDependency = tableList.find { table ->
table.header.key?.stringValue == "dependencies.$dependencyName"
}
if (existingInlinedDependency != null) {
return existingInlinedDependency
}

// Check for cases like:
// [dependencies.xxx]
// name = "foo"
// version = "1.0.0"
val existingInlinedDependencyWithName = tableList.find { table ->
val headerSegments = table.header.key?.segments ?: return@find false
headerSegments.size > 1
&& headerSegments.firstOrNull()?.name == "dependencies"
&& table.entries.any {
entry -> entry.key.stringValue == "name" && entry.value?.stringValue == dependencyName
}
}
if (existingInlinedDependencyWithName != null) {
return existingInlinedDependencyWithName
}

val dependenciesTable = tableList.find {
it.header.key?.segments?.singleOrNull()?.name == "dependencies"
} ?: return null

// Check for cases like:
// [dependencies]
// foo = "1.0.0"
// or
// [dependencies]
// foo = { version = "1.0.0" }
val existingEntry = dependenciesTable.entries.find { entry ->
entry.key.stringValue == dependencyName
&& (entry.value as? TomlInlineTable)?.entries?.find { it.key.stringValue == "name" } == null
}
if (existingEntry != null) {
return existingEntry.value
}

// Check for cases like:
// [dependencies]
// xxx = { name = "foo", version = "1.0.0" }
val existingEntryWithName = dependenciesTable.entries.find { entry ->
(entry.value as? TomlInlineTable)?.entries?.any {
it.key.stringValue == "name" && it.value?.stringValue == dependencyName
} == true
}
if (existingEntryWithName != null) {
return existingEntryWithName.value
}
return null
}

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import com.intellij.util.ProcessingContext
import org.rust.cargo.project.workspace.PackageOrigin
import org.rust.toml.StringLiteralInsertionHandler
import org.rust.toml.findCargoPackageForCargoToml
import org.rust.toml.getPackageTomlFile
import org.rust.toml.getPackageCargoTomlFile
import org.rust.toml.resolve.allFeatures
import org.toml.lang.psi.TomlFile
import org.toml.lang.psi.impl.TomlKeyValueImpl
Expand Down Expand Up @@ -45,7 +45,7 @@ class CargoTomlFeatureDependencyCompletionProvider : CompletionProvider<Completi
for (dep in pkg.dependencies) {
if (dep.pkg.origin == PackageOrigin.STDLIB) continue
// TODO avoid AST loading?
for (feature in dep.pkg.getPackageTomlFile(tomlFile.project)?.allFeatures().orEmpty()) {
for (feature in dep.pkg.getPackageCargoTomlFile(tomlFile.project)?.allFeatures().orEmpty()) {
result.addElement(
LookupElementBuilder
.createWithSmartPointer("${dep.pkg.name}/${feature.text}", feature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.rust.lang.core.psi.ext.RsElement
import org.rust.lang.core.psi.ext.ancestorOrSelf
import org.rust.lang.core.psi.ext.containingCargoPackage
import org.rust.lang.core.psi.ext.elementType
import org.rust.toml.getPackageTomlFile
import org.rust.toml.getPackageCargoTomlFile
import org.rust.toml.resolve.allFeatures
import org.toml.lang.psi.TomlKeySegment

Expand All @@ -39,7 +39,7 @@ import org.toml.lang.psi.TomlKeySegment
object RsCfgFeatureCompletionProvider : RsCompletionProvider() {
override fun addCompletions(parameters: CompletionParameters, context: ProcessingContext, result: CompletionResultSet) {
val pkg = parameters.position.ancestorOrSelf<RsElement>()?.containingCargoPackage ?: return
val pkgToml = pkg.getPackageTomlFile(parameters.originalFile.project) ?: return
val pkgToml = pkg.getPackageCargoTomlFile(parameters.originalFile.project) ?: return

for (feature in pkgToml.allFeatures()) {
result.addElement(rustLookupElementForFeature(feature))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ package org.rust.toml.resolve

import com.intellij.psi.PsiElementResolveResult
import com.intellij.psi.ResolveResult
import org.rust.lang.core.psi.ext.childrenOfType
import org.rust.toml.getValueWithKey
import org.rust.toml.isDependencyListHeader
import org.rust.toml.isFeatureListHeader
import org.rust.toml.isSpecificDependencyTableHeader
import org.rust.toml.*
import org.toml.lang.psi.*
import org.toml.lang.psi.ext.TomlLiteralKind
import org.toml.lang.psi.ext.kind
Expand All @@ -25,7 +21,7 @@ fun TomlFile.resolveFeature(featureName: String, depOnly: Boolean = false): Arra

fun TomlFile.allFeatures(depOnly: Boolean = false): Sequence<TomlKeySegment> {
val explicitFeatures = hashSetOf<String>()
return childrenOfType<TomlTable>()
return tableList
.asSequence()
.flatMap { table ->
val header = table.header
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import org.rust.lang.core.psi.RsLitExpr
import org.rust.lang.core.psi.RsLiteralKind
import org.rust.lang.core.psi.ext.containingCargoPackage
import org.rust.lang.core.psi.kind
import org.rust.toml.getPackageTomlFile
import org.rust.toml.getPackageCargoTomlFile

/**
* Consider "main.rs":
Expand All @@ -31,7 +31,7 @@ class RsCfgFeatureReferenceProvider : PsiReferenceProvider() {
private class RsCfgFeatureReferenceReference(element: RsLitExpr) : PsiPolyVariantReferenceBase<RsLitExpr>(element) {
override fun multiResolve(incompleteCode: Boolean): Array<ResolveResult> {
val literalValue = (element.kind as? RsLiteralKind.String)?.value ?: return ResolveResult.EMPTY_ARRAY
val toml = element.containingCargoPackage?.getPackageTomlFile(element.project) ?: return ResolveResult.EMPTY_ARRAY
val toml = element.containingCargoPackage?.getPackageCargoTomlFile(element.project) ?: return ResolveResult.EMPTY_ARRAY
return toml.resolveFeature(literalValue)
}
}
1 change: 0 additions & 1 deletion src/main/resources/messages/RsBundle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -1419,7 +1419,6 @@ inspection.message.expected.trait.bound.found.impl.trait.type=Expected trait bou
inspection.message.invalid.dyn.keyword=Invalid `dyn` keyword
inspection.duplicated.key.display.name=Duplicated key
intention.name.add.tokio.main=Add `#[tokio::main]`
progress.title.adding.dependency=Adding {0} to dependencies
intention.name.install.tokio.and.add.main=Add tokio to dependencies and add `#[tokio::main]`
dbg.usage=#[dbg] usage
const.generics.defaults=const generics defaults
Expand Down

0 comments on commit d024d4d

Please sign in to comment.