-
Notifications
You must be signed in to change notification settings - Fork 380
/
impl.kt
121 lines (102 loc) · 5.11 KB
/
impl.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
/*
* Use of this source code is governed by the MIT license that can be
* found in the LICENSE file.
*/
package org.rust.ide.refactoring.implementMembers
import com.intellij.codeInsight.hint.HintManager
import com.intellij.openapi.application.runWriteAction
import com.intellij.openapi.editor.Editor
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiWhiteSpace
import org.rust.ide.inspections.import.RsImportHelper.importTypeReferencesFromElements
import org.rust.lang.core.macros.expandedFromRecursively
import org.rust.lang.core.psi.*
import org.rust.lang.core.psi.ext.*
import org.rust.lang.core.types.BoundElement
import org.rust.openapiext.checkReadAccessAllowed
import org.rust.openapiext.checkWriteAccessAllowed
import org.rust.openapiext.checkWriteAccessNotAllowed
/**
 * Entry point of the "implement members" action for a trait `impl`.
 *
 * Looks up the members that can be implemented for [impl], lets the user pick a subset through
 * `showTraitMemberChooser`, and inserts the picked members into the impl body inside a write action.
 *
 * The function starts with `checkWriteAccessNotAllowed()`, so it is presumably meant to be invoked
 * outside of a write action; the write action is opened here only for the actual insertion.
 *
 * @param impl the `impl` item to add members to.
 * @param editor editor used to show an error hint when nothing can be implemented;
 *   when `null`, the function returns silently in that case.
 */
fun generateTraitMembers(impl: RsImplItem, editor: Editor?) {
    checkWriteAccessNotAllowed()

    val lookup = findMembersToImplement(impl)
    if (lookup == null) {
        // Nothing to offer: tell the user (if there is an editor to show the hint in) and stop.
        editor?.let {
            HintManager.getInstance().showErrorHint(it, "No members to implement have been found")
        }
        return
    }
    val (implInfo, trait) = lookup

    val selectedMembers = showTraitMemberChooser(implInfo, impl.project)
    if (selectedMembers.isEmpty()) return

    runWriteAction {
        // `impl.members` is asserted non-null here on the assumption that `findMembersToImplement`
        // only succeeds for impls that have a body. That check is not visible in this file,
        // presumably it lives in `TraitImplementationInfo.create` (TODO confirm).
        insertNewTraitMembers(selectedMembers, impl.members!!, trait)
    }
}
/**
 * Resolves the trait referenced by [impl] and builds the implementation info for it.
 *
 * Calls `checkReadAccessAllowed()` before touching the PSI.
 *
 * @return the implementation info paired with the resolved bound trait, or `null` when
 *   - [impl] has no trait reference or the reference does not resolve to a bound trait,
 *   - `TraitImplementationInfo.create` returns `null`, or
 *   - the `declared` collection of the created info is empty.
 */
private fun findMembersToImplement(impl: RsImplItem): Pair<TraitImplementationInfo, BoundElement<RsTraitItem>>? {
    checkReadAccessAllowed()

    val boundTrait = impl.traitRef?.resolveToBoundTrait() ?: return null
    val info = TraitImplementationInfo.create(boundTrait.element, impl) ?: return null

    return if (info.declared.isEmpty()) null else Pair(info, boundTrait)
}
/**
 * Inserts implementations of the [selected] trait members into the impl body [existingMembers].
 *
 * The new members are generated by `RsPsiFactory.createMembers` using the substitution of [trait]
 * and are then added one by one, each after an anchor chosen from the members already present
 * (or after the opening brace when there is no suitable anchor). After each insertion, extra
 * whitespace is added around the new member when it or its neighbour is a function. Finally,
 * `importTypeReferencesFromElements` is invoked for [existingMembers], [selected] and the
 * trait substitution.
 *
 * Calls `checkWriteAccessAllowed()` first; does nothing when [selected] is empty.
 *
 * @param selected trait members chosen for implementation.
 * @param existingMembers the impl body that receives the new members; it is modified in place.
 * @param trait the bound trait whose member order and substitution are used.
 */
private fun insertNewTraitMembers(
selected: Collection<RsAbstractable>,
existingMembers: RsMembers,
trait: BoundElement<RsTraitItem>
) {
checkWriteAccessAllowed()
if (selected.isEmpty()) return
// Build a template members block for the selected items; its `RsAbstractable` children are
// the elements that get inserted into the real impl body below.
val templateImpl = RsPsiFactory(existingMembers.project).createMembers(selected, trait.subst)
val traitMembers = trait.element.expandedMembers
val newMembers = templateImpl.childrenOfType<RsAbstractable>()
// [1] First, check whether the order of the members already implemented
// matches the order of the members in the trait declaration.
// Each existing member is paired with the index of the first trait member that has the same
// element type and name (`indexOfFirst` yields -1 when there is no such trait member).
// The list is mutable because every inserted member is recorded in it further down.
val existingMembersWithPosInTrait = existingMembers.expandedMembers.map { existingMember ->
Pair(existingMember, traitMembers.indexOfFirst {
it.elementType == existingMember.elementType && it.name == existingMember.name
})
}.toMutableList()
val existingMembersOrder = existingMembersWithPosInTrait.map { it.second }
// "In the right order" means the trait positions are already non-decreasing.
val areExistingMembersInTheRightOrder = existingMembersOrder == existingMembersOrder.sorted()
for ((index, newMember) in newMembers.withIndex()) {
// Position of the new member within the trait, matched by element type and name.
val posInTrait = traitMembers.indexOfFirst {
it.elementType == newMember.elementType && it.name == newMember.name
}
var indexedExistingMembers = existingMembersWithPosInTrait.withIndex()
// If [1] does not hold, the first new member is appended at the end of the implementation
// (the candidates are left unfiltered, so the anchor is the last existing member).
// All the other ones are then inserted at the right position relative to that very first one.
if (areExistingMembersInTheRightOrder || index > 0) {
indexedExistingMembers = indexedExistingMembers.filter { it.value.second < posInTrait }
}
// The anchor is the last remaining candidate, keeping its index in the bookkeeping list.
// When the member has a non-null `expandedFromRecursively`, that element is used as the PSI
// anchor instead of the member itself (presumably the macro call it was expanded from, TODO confirm).
// With no candidate at all, the anchor is the opening brace and the index is -1.
val anchor = indexedExistingMembers
.lastOrNull()
?.let {
val member = it.value.first
IndexedValue(it.index, member.expandedFromRecursively ?: member)
}
?: IndexedValue(-1, existingMembers.lbrace)
val addedMember = existingMembers.addAfter(newMember, anchor.value) as RsAbstractable
// Record the inserted member right after its anchor so that later iterations can anchor on it.
existingMembersWithPosInTrait.add(anchor.index + 1, Pair(addedMember, posInTrait))
// If the newly added item or its neighbour is a function, add extra whitespace between them.
// `prev`/`next` are the first `RsAbstractable` or `RsMacroCall` found among the left/right
// siblings (presumably the nearest one in each direction; verify the sibling iteration order).
val prev = addedMember.leftSiblings.find { it is RsAbstractable || it is RsMacroCall }
if (prev != null && (prev is RsFunction || addedMember is RsFunction)) {
val whitespaces = createExtraWhitespacesAroundFunction(prev, addedMember)
existingMembers.addBefore(whitespaces, addedMember)
}
val next = addedMember.rightSiblings.find { it is RsAbstractable || it is RsMacroCall }
if (next != null && (next is RsFunction || addedMember is RsFunction)) {
val whitespaces = createExtraWhitespacesAroundFunction(addedMember, next)
existingMembers.addAfter(whitespaces, addedMember)
}
}
// Runs once, after all members have been inserted.
importTypeReferencesFromElements(existingMembers, selected, trait.subst)
}
/**
 * Creates the whitespace needed to separate [left] and [right] by at least two line breaks.
 *
 * Counts the `'\n'` characters in all [PsiWhiteSpace] elements among the right siblings of [left]
 * that come before [right], and returns a whitespace element made of the newlines still missing
 * to reach two (an empty text is passed to the factory when two or more are already present).
 *
 * @param left the element whose right siblings are scanned.
 * @param right the element at which the scan stops (it is not counted itself).
 * @return a whitespace element created through `RsPsiFactory.createWhitespace`.
 */
private fun createExtraWhitespacesAroundFunction(left: PsiElement, right: PsiElement): PsiElement {
    var newlinesBetween = 0
    for (sibling in left.rightSiblings) {
        if (sibling == right) break
        if (sibling is PsiWhiteSpace) {
            newlinesBetween += sibling.text.count { it == '\n' }
        }
    }

    val missingNewlines = (2 - newlinesBetween).coerceAtLeast(0)
    return RsPsiFactory(left.project).createWhitespace("\n".repeat(missingNewlines))
}