Skip to content

Commit

Permalink
INT: add support for generic types to UnElideLifetimesIntention
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user authored and Kobzol committed Sep 17, 2020
1 parent 961fb7d commit 88241a7
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 92 deletions.
138 changes: 109 additions & 29 deletions src/main/kotlin/org/rust/ide/intentions/UnElideLifetimesIntention.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.types.infer.hasReEarlyBounds
import org.rust.lang.core.types.regions.Region
import org.rust.lang.core.types.ty.Ty
import org.rust.lang.core.types.ty.TyAdt
import org.rust.lang.core.types.type

class UnElideLifetimesIntention : RsElementBaseIntentionAction<RsFunction>() {
override fun getText() = "Un-elide lifetimes"
Expand All @@ -19,34 +24,51 @@ class UnElideLifetimesIntention : RsElementBaseIntentionAction<RsFunction>() {
if (element is RsDocCommentImpl) return null
val fn = element.ancestorOrSelf<RsFunction>(stopAt = RsBlock::class.java) ?: return null

if ((fn.retType?.typeReference as? RsRefLikeType)?.lifetime != null) return null
val ctx = getLifetimeContext(fn)
val outputLifetimes = ctx.output?.lifetimes
if (outputLifetimes != null) {
if (outputLifetimes.any { it != null } || outputLifetimes.size > 1) return null
}

val args = fn.allRefArgs
val refArgs = ctx.inputs + listOfNotNull(ctx.self)
if (refArgs.isEmpty() || refArgs.any { it.lifetimes.any { lifetime -> lifetime != null } }) return null

if (args.isEmpty() || args.any { it.lifetime != null }) return null
return fn
}

override fun invoke(project: Project, editor: Editor, ctx: RsFunction) {
ctx.allRefArgs.asSequence().zip(nameGenerator).forEach {
it.first.replace(createParam(project, it.first, it.second))
val (inputs, output, self) = getLifetimeContext(ctx)

val addedLifetimes = mutableListOf<String>()
val generator = nameGenerator.iterator()
(listOfNotNull(self) + inputs).forEach { ref ->
val names = ref.lifetimes.map { generator.next() }
addLifetimeParameter(ref, names)
addedLifetimes.addAll(names)
}

// generic params
val genericParams = RsPsiFactory(project).createTypeParameterList(
ctx.allRefArgs.mapNotNull { it.lifetime?.text } + ctx.typeParameters.map { it.text }
addedLifetimes + ctx.typeParameters.map { it.text }
)
ctx.typeParameterList?.replace(genericParams) ?: ctx.addAfter(genericParams, ctx.identifier)

// return type
val retType = ctx.retType?.typeReference?.skipParens() as? RsRefLikeType ?: return
if (output == null) return

if ((ctx.selfParameter != null) || (ctx.allRefArgs.drop(1).none())) {
retType.replace(createRefType(project, retType, ctx.allRefArgs.first().lifetime!!.text))
if (self != null || inputs.size == 1) {
addLifetimeParameter(output, listOf(addedLifetimes.first()))
} else {
val lifeTime = (retType.replace(createRefType(project, retType, "'unknown"))
as RsRefLikeType).lifetime ?: return
editor.selectionModel.setSelection(lifeTime.startOffset + 1, lifeTime.endOffset)
val unknownLifetime = "'_"
val element = addLifetimeParameter(output, listOf(unknownLifetime))
element.accept(object : RsRecursiveVisitor() {
override fun visitLifetime(o: RsLifetime) {
if (o.quoteIdentifier.text == unknownLifetime) {
val start = o.startOffset + 1
editor.caretModel.moveToOffset(start)
editor.selectionModel.setSelection(start, o.endOffset)
}
}
})
}
}

Expand All @@ -56,27 +78,85 @@ class UnElideLifetimesIntention : RsElementBaseIntentionAction<RsFunction>() {
val index = it / abcSize
return@map if (index == 0) "'$letter" else "'$letter$index"
}
}

private fun createRefType(project: Project, origin: RsRefLikeType, lifeTimeName: String): RsRefLikeType =
RsPsiFactory(project).createType(origin.text.replaceFirst("&", "&$lifeTimeName ")) as RsRefLikeType

private fun createParam(project: Project, origin: PsiElement, lifeTimeName: String): PsiElement =
RsPsiFactory(project).createMethodParam(origin.text.replaceFirst("&", "&$lifeTimeName "))
private sealed class PotentialLifetimeRef(val element: RsElement) {
data class Self(val self: RsSelfParameter) : PotentialLifetimeRef(self)
data class RefLike(val ref: RsRefLikeType) : PotentialLifetimeRef(ref)
data class BaseType(val baseType: RsBaseType, val type: Ty) : PotentialLifetimeRef(baseType) {
val typeLifetimes: List<Region>
get() = when (val type = baseType.type) {
is TyAdt -> type.regionArguments.filter { it.hasReEarlyBounds }
else -> emptyList()
}
}

private val RsFunction.allRefArgs: List<PsiElement> get() {
val selfAfg: List<PsiElement> = listOfNotNull(selfParameter)
val params: List<PsiElement> = valueParameters
.filter { param ->
val type = param.typeReference?.skipParens()
type is RsRefLikeType && type.isRef
val lifetimes: List<RsLifetime?>
get() = when (this) {
is Self -> listOf(self.lifetime)
is RefLike -> listOf(ref.lifetime)
is BaseType -> {
val lifetimes = typeLifetimes
val actualLifetimes = baseType.path?.typeArgumentList?.lifetimeList
lifetimes.indices.map { actualLifetimes?.getOrNull(it) }
}
return selfAfg + params
}
}

private fun isPotentialLifetimeAdt(ref: RsTypeReference): Boolean {
return when (val type = ref.type) {
is TyAdt -> type.regionArguments.all { it.hasReEarlyBounds }
else -> false
}
}

private val PsiElement.lifetime: RsLifetime? get() =
when (this) {
is RsSelfParameter -> lifetime
is RsValueParameter -> (typeReference?.skipParens() as? RsRefLikeType)?.lifetime
private fun parsePotentialLifetimeType(ref: RsTypeReference): PotentialLifetimeRef? {
return when {
ref is RsRefLikeType -> PotentialLifetimeRef.RefLike(ref)
ref is RsBaseType && isPotentialLifetimeAdt(ref) -> PotentialLifetimeRef.BaseType(ref, ref.type)
else -> null
}
}

private data class LifetimeContext(
val inputs: List<PotentialLifetimeRef>,
val output: PotentialLifetimeRef?,
val self: PotentialLifetimeRef?
)

private fun getLifetimeContext(fn: RsFunction): LifetimeContext {
val inputArgs = fn.valueParameters.mapNotNull { elem -> elem.typeReference?.let { parsePotentialLifetimeType(it) } }
val retType = fn.retType?.typeReference?.let { parsePotentialLifetimeType(it) }

return LifetimeContext(inputArgs, retType, fn.selfParameter?.let { PotentialLifetimeRef.Self(it) })
}

private fun addLifetimeParameter(ref: PotentialLifetimeRef, names: List<String>): PsiElement {
val factory = RsPsiFactory(ref.element.project)
return when (ref) {
is PotentialLifetimeRef.Self -> {
val elem = ref.element
elem.replace(factory.createMethodParam(elem.text.replaceFirst("&", "&${names[0]} ")))
}
is PotentialLifetimeRef.RefLike -> {
val elem = ref.element
val typeRef = factory.createType(elem.text.replaceFirst("&", "&${names[0]} "))
elem.replace(typeRef)
}
is PotentialLifetimeRef.BaseType -> {
val elem = ref.baseType
val typeList = names.toMutableList()

val typeArguments = elem.path?.typeArgumentList
if (typeArguments != null) {
typeList += typeArguments.typeReferenceList.map { it.text }
typeList += typeArguments.assocTypeBindingList.map { it.text }
}

val baseTypeName = elem.name
val types = factory.createTypeParameterList(typeList)
val replacement = factory.createType("$baseTypeName${types.text}")
elem.replace(replacement)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,106 +6,125 @@
package org.rust.ide.intentions

class UnElideLifetimesIntentionTest : RsIntentionTestBase(UnElideLifetimesIntention::class) {
fun `test unavailable`() = doUnavailableTest(
"""
fun `test unavailable without references`() = doUnavailableTest("""
fn bar/*caret*/(x: i32) -> i32 {}
"""
)
""")

fun `test unavailable block body`() = doUnavailableTest(
"""
fun `test unavailable in block body`() = doUnavailableTest("""
fn bar(x: &i32) {/*caret*/}
"""
)
""")

fun `test unavailable doc comment`() = doUnavailableTest(
"""
fun `test unavailable in doc comment`() = doUnavailableTest("""
/// ```
/// /*caret*/
/// ```
fn bar(x: &i32) {}
"""
)
""")

fun `test unavailable no args`() = doUnavailableTest(
"""
fun `test unavailable without args`() = doUnavailableTest("""
fn bar(/*caret*/) {}
"""
)
""")

fun `test unavailable un elided`() = doUnavailableTest(
"""
fun `test unavailable with explicit lifetime`() = doUnavailableTest("""
fn bar<'a>(x: &'a /*caret*/ i32) {}
"""
)
""")

fun `test simple`() = doAvailableTest(
"""
fun `test simple`() = doAvailableTest("""
fn foo(p: &/*caret*/ i32) -> & i32 { p }
"""
,
"""
""", """
fn foo<'a>(p: &/*caret*/'a i32) -> &'a i32 { p }
"""
)
""")

fun `test generic type`() = doAvailableTest(
"""
fun `test mut ref`() = doAvailableTest("""
fn foo(p: &/*caret*/mut i32) -> & i32 { p }
""", """
fn foo<'a>(p: &/*caret*/'a mut i32) -> &'a i32 { p }
""")

fun `test nested ref`() = doAvailableTest("""
fn foo(p: &&/*caret*/ i32) -> & i32 { unimplemented!() }
""", """
fn foo<'a>(p: &'a &/*caret*/i32) -> &'a i32 { unimplemented!() }
""")

fun `test generic type`() = doAvailableTest("""
fn foo<T>(p1:/*caret*/ &i32, p2: T) -> & i32 { p }
"""
,
"""
""", """
fn foo<'a, T>(p1:/*caret*/ &'a i32, p2: T) -> &'a i32 { p }
"""
)
""")

fun `test lifetime type as parameter`() = doAvailableTest("""
struct S<'a> { x: &'a u32 }
fn make_s(x:/*caret*/ S) { unimplemented!() }
""", """
struct S<'a> { x: &'a u32 }
fn make_s<'a>(x:/*caret*/ S<'a>) { unimplemented!() }
""")

fun `test lifetime type as return value`() = doAvailableTest("""
struct S<'a> { x: &'a i32 }
fn make_s(x:/*caret*/ &i32) -> S { unimplemented!() }
""", """
struct S<'a> { x: &'a i32 }
fn make_s<'a>(x:/*caret*/ &'a i32) -> S<'a> { unimplemented!() }
""")

fun `test unknown`() = doAvailableTest(
"""
fun `test struct parameter with multiple lifetimes`() = doAvailableTest("""
struct S<'a, 'b> { x: &'a u32, y: &'b u32 }
fn make_s(x:/*caret*/ S) { unimplemented!() }
""", """
struct S<'a, 'b> { x: &'a u32, y: &'b u32 }
fn make_s<'a, 'b>(x:/*caret*/ S<'a, 'b>) { unimplemented!() }
""")

fun `test struct return type with multiple lifetimes`() = doUnavailableTest("""
struct S<'a, 'b> { x: &'a u32, y: &'b u32 }
fn make_s(x:/*caret*/ &i32) -> S { unimplemented!() }
""")

fun `test lifetime type complex struct`() = doAvailableTest("""
struct S<'a, T, X> { x: &'a T, y: X }
fn make_s<X>(x:/*caret*/ S<u32, X>) -> S<u32, X> { unimplemented!() }
""", """
struct S<'a, T, X> { x: &'a T, y: X }
fn make_s<'a, X>(x:/*caret*/ S<'a, u32, X>) -> S<'a, u32, X> { unimplemented!() }
""")

fun `test unknown`() = doAvailableTest("""
fn foo(p1: &i32,/*caret*/ p2: &i32) -> &i32 { p2 }
"""
,
"""
fn foo<'a, 'b>(p1: &'a i32,/*caret*/ p2: &'b i32) -> &'<selection>unknown</selection> i32 { p2 }
"""
)

fun `test method decl`() = doAvailableTest(
"""
""", """
fn foo<'a, 'b>(p1: &'a i32, p2: &'b i32) -> &'<selection>_</selection> i32 { p2 }
""")

fun `test method decl`() = doAvailableTest("""
trait Foo {
fn /*caret*/bar(&self, x: &i32, y: &i32, x: i32) -> &i32;
fn /*caret*/bar(&self, x: &i32, y: &i32, z: i32) -> &i32;
}
"""
,
"""
""", """
trait Foo {
fn /*caret*/bar<'a, 'b, 'c>(&'a self, x: &'b i32, y: &'c i32, x: i32) -> &'a i32;
fn /*caret*/bar<'a, 'b, 'c>(&'a self, x: &'b i32, y: &'c i32, z: i32) -> &'a i32;
}
"""
)
""")

fun `test method impl`() = doAvailableTest(
"""
fun `test method impl`() = doAvailableTest("""
trait Foo {
fn bar(&self, x: &i32, y: &i32, x: i32) -> &i32;
fn bar(&self, x: &i32, y: &i32, z: i32) -> &i32;
}
struct S {}
impl Foo for S {
fn /*caret*/bar(&self, x: &i32, y: &i32, x: i32) -> &i32 {
fn /*caret*/bar(&self, x: &i32, y: &i32, z: i32) -> &i32 {
unimplemented!()
}
}
"""
,
"""
""", """
trait Foo {
fn bar(&self, x: &i32, y: &i32, x: i32) -> &i32;
fn bar(&self, x: &i32, y: &i32, z: i32) -> &i32;
}
struct S {}
impl Foo for S {
fn /*caret*/bar<'a, 'b, 'c>(&'a self, x: &'b i32, y: &'c i32, x: i32) -> &'a i32 {
fn /*caret*/bar<'a, 'b, 'c>(&'a self, x: &'b i32, y: &'c i32, z: i32) -> &'a i32 {
unimplemented!()
}
}
"""
)
""")
}

0 comments on commit 88241a7

Please sign in to comment.