Skip to content

Commit

Permalink
Merge #6471
Browse files Browse the repository at this point in the history
6471: INT: add support for associated functions in CreateFunctionIntention r=vlad20012 a=Kobzol

This PR enables `CreateFunctionIntention` for associated methods.
![associated-method](https://user-images.githubusercontent.com/4539057/100888964-f12ecb00-34b6-11eb-9aaf-c70b99faa9db.gif)

changelog: Enable `CreateFunctionIntention` also for associated methods.


Co-authored-by: Jakub Beránek <berykubik@gmail.com>
  • Loading branch information
bors[bot] and Kobzol committed Jan 17, 2021
2 parents ac720c5 + c02504d commit 004e240
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 15 deletions.
82 changes: 67 additions & 15 deletions src/main/kotlin/org/rust/ide/intentions/CreateFunctionIntention.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,24 @@ import org.rust.openapiext.createSmartPointer
class CreateFunctionIntention : RsElementBaseIntentionAction<CreateFunctionIntention.Context>() {
override fun getFamilyName() = "Create function"

sealed class FunctionInsertionTarget {
abstract val module: RsMod

class Module(val target: RsMod): FunctionInsertionTarget() {
override val module: RsMod = target
}
class Item(val item: RsStructOrEnumItemElement): FunctionInsertionTarget() {
override val module: RsMod = item.containingMod
}
}

sealed class Context(val name: String, val callElement: PsiElement) {
abstract val visibility: String
abstract val arguments: RsValueArgumentList
abstract val returnType: Ty?
open val implItem: RsImplItem? = null

class Function(val callExpr: RsCallExpr, name: String, val module: RsMod) : Context(name, callExpr) {
open class Function(val callExpr: RsCallExpr, name: String, val module: RsMod) : Context(name, callExpr) {
override val visibility: String = when {
callExpr.containingCrate != module.containingCrate -> "pub "
callExpr.containingMod != module -> "pub(crate) "
Expand All @@ -39,6 +51,16 @@ class CreateFunctionIntention : RsElementBaseIntentionAction<CreateFunctionInten
override val returnType: Ty? = callExpr.expectedType
}

class AssociatedFunction(
callExpr: RsCallExpr,
name: String,
module: RsMod,
val item: RsStructOrEnumItemElement
) : Function(callExpr, name, module) {
override val implItem: RsImplItem?
get() = super.implItem
}

class Method(val methodCall: RsMethodCall, name: String, val item: RsStructOrEnumItemElement)
: Context(name, methodCall) {
override val visibility: String
Expand All @@ -60,13 +82,20 @@ class CreateFunctionIntention : RsElementBaseIntentionAction<CreateFunctionInten
val path = element.parentOfType<RsPath>()
val functionCall = path?.parentOfType<RsCallExpr>()
if (functionCall != null) {
if (path.resolveStatus != PathResolveStatus.UNRESOLVED) return null
if (!functionCall.expr.isAncestorOf(path)) return null
if (path.resolveStatus != PathResolveStatus.UNRESOLVED) return null

val module = getTargetModuleForFunction(path) ?: return null
val target = getTargetItemForFunction(path) ?: return null
val name = path.referenceName ?: return null
text = "Create function `$name`"
return Context.Function(functionCall, name, module)

return if (target is FunctionInsertionTarget.Item) {
text = "Create associated function `${target.item.name}::$name`"
Context.AssociatedFunction(functionCall, name, target.module, target.item)
}
else {
text = "Create function `$name`"
Context.Function(functionCall, name, target.module)
}
}
val methodCall = element.parentOfType<RsMethodCall>()
if (methodCall != null) {
Expand Down Expand Up @@ -118,19 +147,26 @@ class CreateFunctionIntention : RsElementBaseIntentionAction<CreateFunctionInten
return factory.tryCreateFunction("$visibility fn $functionName$genericParams($paramsText)$returnType $whereClause {\n unimplemented!()\n}")
}

private fun getTargetModuleForFunction(path: RsPath): RsMod? {
private fun getTargetItemForFunction(path: RsPath): FunctionInsertionTarget? {
if (path.qualifier != null) {
val mod = path.qualifier?.reference?.resolve() as? RsMod
if (mod?.containingCargoPackage?.origin != PackageOrigin.WORKSPACE) return null
if (!isUnitTestMode && !mod.isWritable) return null
return mod
val item = path.qualifier?.reference?.resolve() as? RsQualifiedNamedElement
if (item?.containingCargoPackage?.origin != PackageOrigin.WORKSPACE) return null
if (!isUnitTestMode && !item.isWritable) return null

return when (item) {
is RsMod -> FunctionInsertionTarget.Module(item)
is RsStructOrEnumItemElement -> FunctionInsertionTarget.Item(item)
else -> null
}
}
return path.containingMod
return FunctionInsertionTarget.Module(path.containingMod)
}

private data class CallableConfig(val parameters: List<String>,
val returnType: Ty,
val genericConstraints: GenericConstraints)
private data class CallableConfig(
val parameters: List<String>,
val returnType: Ty,
val genericConstraints: GenericConstraints
)

private fun getCallableConfig(ctx: Context): CallableConfig {
val callExpr = ctx.callElement
Expand All @@ -145,7 +181,7 @@ class CreateFunctionIntention : RsElementBaseIntentionAction<CreateFunctionInten
.filterByTypes(arguments.exprList.map { it.type }.plus(returnType))

val filteredConstraints = if (ctx is Context.Method) {
val params = ctx.callElement.parentOfType<RsImplItem>()?.typeParameters.orEmpty()
val params = callExpr.parentOfType<RsImplItem>()?.typeParameters.orEmpty()
genericConstraints.withoutTypes(params)
} else genericConstraints

Expand All @@ -156,11 +192,27 @@ class CreateFunctionIntention : RsElementBaseIntentionAction<CreateFunctionInten
val sourceFunction = ctx.callElement.parentOfType<RsFunction>() ?: return null

return when (ctx) {
is Context.AssociatedFunction -> insertAssociatedFunction(ctx.item, function)
is Context.Function -> insertFunction(ctx.module, sourceFunction, function)
is Context.Method -> insertMethod(ctx.item, sourceFunction, function)
}
}

private fun insertAssociatedFunction(
item: RsStructOrEnumItemElement,
function: RsFunction
): RsFunction? {
val psiFactory = RsPsiFactory(item.project)
val name = item.name ?: return null

val newImpl = psiFactory.createInherentImplItem(name, item.typeParameterList, item.whereClause)
val impl = item.parent.addAfter(newImpl, item) as RsImplItem

return impl.members?.let {
it.addBefore(function, it.rbrace) as RsFunction
}
}

private fun insertFunction(
targetModule: RsMod,
sourceFunction: RsFunction,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ class CreateFunctionIntentionTest : RsIntentionTestBase(CreateFunctionIntention:
}
""")

fun `test unavailable on trait associated function`() = doUnavailableTest("""
trait Trait {}
fn foo() {
Trait::baz/*caret*/();
}
""")

fun `test create function`() = doAvailableTest("""
fn main() {
/*caret*/foo();
Expand Down Expand Up @@ -653,4 +661,65 @@ class CreateFunctionIntentionTest : RsIntentionTestBase(CreateFunctionIntention:
}
}
""")

fun `test create associated function for struct`() = doAvailableTest("""
struct S;
fn foo() {
S::bar/*caret*/(1, 2);
}
""", """
struct S;
impl S {
fn bar(p0: i32, p1: i32) {
unimplemented!()
}
}
fn foo() {
S::bar(1, 2);
}
""")

fun `test create associated function for enum`() = doAvailableTest("""
enum S {
V1
}
fn foo() {
S::bar/*caret*/(1, 2);
}
""", """
enum S {
V1
}
impl S {
fn bar(p0: i32, p1: i32) {
unimplemented!()
}
}
fn foo() {
S::bar(1, 2);
}
""")

fun `test create associated function for generic struct`() = doAvailableTest("""
struct S<T>(T);
fn foo() {
S::<u32>::bar/*caret*/(1, 2);
}
""", """
struct S<T>(T);
impl<T> S<T> {
fn bar(p0: i32, p1: i32) {
unimplemented!()
}
}
fn foo() {
S::<u32>::bar(1, 2);
}
""")
}

0 comments on commit 004e240

Please sign in to comment.