Skip to content

Commit

Permalink
Support completions for extension definition parameter (scala#18331)
Browse files Browse the repository at this point in the history
Extension methods are extended into normal definitions.
Because of that typed trees don't include any information about the
extension method definition parameter:
```scala
extension (x: In@@)
```
In order to add completions, we check if there is an exact path to the
untyped tree, and if not, we fall back to it. There may also be more
possible cases like that, but I can't think of any at the moment.
  • Loading branch information
rochala committed Oct 3, 2023
1 parent 38ee06e commit de4ad2b
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 209 deletions.
128 changes: 83 additions & 45 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package dotty.tools.dotc.interactive

import scala.language.unsafeNulls

import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.ast.NavigateAST
import dotty.tools.dotc.config.Printers.interactiv
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
Expand All @@ -25,6 +24,10 @@ import dotty.tools.dotc.util.SourcePosition

import scala.collection.mutable
import scala.util.control.NonFatal
import dotty.tools.dotc.core.ContextOps.localContext
import dotty.tools.dotc.core.Names
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Symbols

/**
* One of the results of a completion query.
Expand All @@ -37,18 +40,17 @@ import scala.util.control.NonFatal
*/
case class Completion(label: String, description: String, symbols: List[Symbol])

object Completion {
object Completion:

import dotty.tools.dotc.ast.tpd._

/** Get possible completions from tree at `pos`
*
* @return offset and list of symbols for possible completions
*/
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) = {
val path = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
def completions(pos: SourcePosition)(using Context): (Int, List[Completion]) =
val path: List[Tree] = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
computeCompletions(pos, path)(using Interactive.contextOfPath(path).withPhase(Phases.typerPhase))
}

/**
* Inspect `path` to determine what kinds of symbols should be considered.
Expand All @@ -60,10 +62,11 @@ object Completion {
*
* Otherwise, provide no completion suggestion.
*/
def completionMode(path: List[Tree], pos: SourcePosition): Mode =
path match {
case Ident(_) :: Import(_, _) :: _ => Mode.ImportOrExport
case (ref: RefTree) :: _ =>
def completionMode(path: List[untpd.Tree], pos: SourcePosition): Mode =
path match
case untpd.Ident(_) :: untpd.Import(_, _) :: _ => Mode.ImportOrExport
case untpd.Ident(_) :: (_: untpd.ImportSelector) :: _ => Mode.ImportOrExport
case (ref: untpd.RefTree) :: _ =>
if (ref.name.isTermName) Mode.Term
else if (ref.name.isTypeName) Mode.Type
else Mode.None
Expand All @@ -72,9 +75,8 @@ object Completion {
if sel.imported.span.contains(pos.span) then Mode.ImportOrExport
else Mode.None // Can't help completing the renaming

case (_: ImportOrExport) :: _ => Mode.ImportOrExport
case (_: untpd.ImportOrExport) :: _ => Mode.ImportOrExport
case _ => Mode.None
}

/** When dealing with <errors> in varios palces we check to see if they are
* due to incomplete backticks. If so, we ensure we get the full prefix
Expand All @@ -101,10 +103,13 @@ object Completion {
case (sel: untpd.ImportSelector) :: _ =>
completionPrefix(sel.imported :: Nil, pos)

case untpd.Ident(_) :: (sel: untpd.ImportSelector) :: _ if !sel.isGiven =>
completionPrefix(sel.imported :: Nil, pos)

case (tree: untpd.ImportOrExport) :: _ =>
tree.selectors.find(_.span.contains(pos.span)).map { selector =>
tree.selectors.find(_.span.contains(pos.span)).map: selector =>
completionPrefix(selector :: Nil, pos)
}.getOrElse("")
.getOrElse("")

// Foo.`se<TAB> will result in Select(Ident(Foo), <error>)
case (select: untpd.Select) :: _ if select.name == nme.ERROR =>
Expand All @@ -118,27 +123,65 @@ object Completion {
if (ref.name == nme.ERROR) ""
else ref.name.toString.take(pos.span.point - ref.span.point)

case _ =>
""
case _ => ""

end completionPrefix

/** Inspect `path` to determine the offset where the completion result should be inserted. */
def completionOffset(path: List[Tree]): Int =
path match {
case (ref: RefTree) :: _ => ref.span.point
def completionOffset(untpdPath: List[untpd.Tree]): Int =
untpdPath match {
case (ref: untpd.RefTree) :: _ => ref.span.point
case _ => 0
}

private def computeCompletions(pos: SourcePosition, path: List[Tree])(using Context): (Int, List[Completion]) = {
val mode = completionMode(path, pos)
val rawPrefix = completionPrefix(path, pos)
/** Some information about the trees is lost after Typer such as Extension method construct
* is expanded into methods. In order to support completions in those cases
* we have to rely on untyped trees and only when types are necessary use typed trees.
*/
def resolveTypedOrUntypedPath(tpdPath: List[Tree], pos: SourcePosition)(using Context): List[untpd.Tree] =
lazy val untpdPath: List[untpd.Tree] = NavigateAST
.pathTo(pos.span, List(ctx.compilationUnit.untpdTree), true).collect:
case untpdTree: untpd.Tree => untpdTree

tpdPath match
case (_: Bind) :: _ => tpdPath
case (_: untpd.TypTree) :: _ => tpdPath
case _ => untpdPath

/** Handle case when cursor position is inside extension method construct.
* The extension method construct is then desugared into methods, and consturct parameters
* are no longer a part of a typed tree, but instead are prepended to method parameters.
*
* @param untpdPath The typed or untyped path to the tree that is being completed
* @param tpdPath The typed path that will be returned if no extension method construct is found
* @param pos The cursor position
*
* @return Typed path to the parameter of the extension construct if found or tpdPath
*/
private def typeCheckExtensionConstructPath(
untpdPath: List[untpd.Tree], tpdPath: List[Tree], pos: SourcePosition
)(using Context): List[Tree] =
untpdPath.collectFirst:
case untpd.ExtMethods(paramss, _) =>
val enclosingParam = paramss.flatten.find(_.span.contains(pos.span))
enclosingParam.map: param =>
ctx.typer.index(paramss.flatten)
val typedEnclosingParam = ctx.typer.typed(param)
Interactive.pathTo(typedEnclosingParam, pos.span)
.flatten.getOrElse(tpdPath)

private def computeCompletions(pos: SourcePosition, tpdPath: List[Tree])(using Context): (Int, List[Completion]) =
val path0 = resolveTypedOrUntypedPath(tpdPath, pos)
val mode = completionMode(path0, pos)
val rawPrefix = completionPrefix(path0, pos)

val hasBackTick = rawPrefix.headOption.contains('`')
val prefix = if hasBackTick then rawPrefix.drop(1) else rawPrefix

val completer = new Completer(mode, prefix, pos)

val completions = path match {
val adjustedPath = typeCheckExtensionConstructPath(path0, tpdPath, pos)
val completions = adjustedPath match
// Ignore synthetic select from `This` because in code it was `Ident`
// See example in dotty.tools.languageserver.CompletionTest.syntheticThis
case Select(qual @ This(_), _) :: _ if qual.span.isSynthetic => completer.scopeCompletions
Expand All @@ -147,21 +190,19 @@ object Completion {
case (tree: ImportOrExport) :: _ => completer.directMemberCompletions(tree.expr)
case (_: untpd.ImportSelector) :: Import(expr, _) :: _ => completer.directMemberCompletions(expr)
case _ => completer.scopeCompletions
}

val describedCompletions = describeCompletions(completions)
val backtickedCompletions =
describedCompletions.map(completion => backtickCompletions(completion, hasBackTick))

val offset = completionOffset(path)
val offset = completionOffset(path0)

interactiv.println(i"""completion with pos = $pos,
| prefix = ${completer.prefix},
| term = ${completer.mode.is(Mode.Term)},
| type = ${completer.mode.is(Mode.Type)}
| results = $backtickedCompletions%, %""")
(offset, backtickedCompletions)
}

def backtickCompletions(completion: Completion, hasBackTick: Boolean) =
if hasBackTick || needsBacktick(completion.label) then
Expand All @@ -174,17 +215,17 @@ object Completion {
// https://github.com/scalameta/metals/blob/main/mtags/src/main/scala/scala/meta/internal/mtags/KeywordWrapper.scala
// https://github.com/com-lihaoyi/Ammonite/blob/73a874173cd337f953a3edc9fb8cb96556638fdd/amm/util/src/main/scala/ammonite/util/Model.scala
private def needsBacktick(s: String) =
val chunks = s.split("_", -1)
val chunks = s.split("_", -1).nn

val validChunks = chunks.zipWithIndex.forall { case (chunk, index) =>
chunk.forall(Chars.isIdentifierPart) ||
(chunk.forall(Chars.isOperatorPart) &&
chunk.nn.forall(Chars.isIdentifierPart) ||
(chunk.nn.forall(Chars.isOperatorPart) &&
index == chunks.length - 1 &&
!(chunks.lift(index - 1).contains("") && index - 1 == 0))
}

val validStart =
Chars.isIdentifierStart(s(0)) || chunks(0).forall(Chars.isOperatorPart)
Chars.isIdentifierStart(s(0)) || chunks(0).nn.forall(Chars.isOperatorPart)

val valid = validChunks && validStart && !keywords.contains(s)

Expand Down Expand Up @@ -216,7 +257,7 @@ object Completion {
* For the results of all `xyzCompletions` methods term names and type names are always treated as different keys in the same map
* and they never conflict with each other.
*/
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition) {
class Completer(val mode: Mode, val prefix: String, pos: SourcePosition):
/** Completions for terms and types that are currently in scope:
* the members of the current class, local definitions and the symbols that have been imported,
* recursively adding completions from outer scopes.
Expand All @@ -230,7 +271,7 @@ object Completion {
* (even if the import follows it syntactically)
* - a more deeply nested import shadowing a member or a local definition causes an ambiguity
*/
def scopeCompletions(using context: Context): CompletionMap = {
def scopeCompletions(using context: Context): CompletionMap =
val mappings = collection.mutable.Map.empty[Name, List[ScopedDenotations]].withDefaultValue(List.empty)
def addMapping(name: Name, denots: ScopedDenotations) =
mappings(name) = mappings(name) :+ denots
Expand Down Expand Up @@ -302,7 +343,7 @@ object Completion {
}

resultMappings
}
end scopeCompletions

/** Widen only those types which are applied or are exactly nothing
*/
Expand Down Expand Up @@ -335,16 +376,16 @@ object Completion {
/** Completions introduced by imports directly in this context.
* Completions from outer contexts are not included.
*/
private def importedCompletions(using Context): CompletionMap = {
private def importedCompletions(using Context): CompletionMap =
val imp = ctx.importInfo

def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
imp.site.member(name).alternatives
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }

if imp == null then
Map.empty
else
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
imp.site.member(name).alternatives
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }

val givenImports = imp.importedImplicits
.map { ref => (ref.implicitName: Name, ref.underlyingRef.denot.asSingleDenotation) }
.filter((name, denot) => include(denot, name))
Expand All @@ -370,7 +411,7 @@ object Completion {
}.toSeq.groupByName

givenImports ++ wildcardMembers ++ explicitMembers
}
end importedCompletions

/** Completions from implicit conversions including old style extensions using implicit classes */
private def implicitConversionMemberCompletions(qual: Tree)(using Context): CompletionMap =
Expand Down Expand Up @@ -532,7 +573,6 @@ object Completion {
extension [N <: Name](namedDenotations: Seq[(N, SingleDenotation)])
@annotation.targetName("groupByNameTupled")
def groupByName: CompletionMap = namedDenotations.groupMap((name, denot) => name)((name, denot) => denot)
}

private type CompletionMap = Map[Name, Seq[SingleDenotation]]

Expand All @@ -545,11 +585,11 @@ object Completion {
* The completion mode: defines what kinds of symbols should be included in the completion
* results.
*/
class Mode(val bits: Int) extends AnyVal {
class Mode(val bits: Int) extends AnyVal:
def is(other: Mode): Boolean = (bits & other.bits) == other.bits
def |(other: Mode): Mode = new Mode(bits | other.bits)
}
object Mode {

object Mode:
/** No symbol should be included */
val None: Mode = new Mode(0)

Expand All @@ -561,6 +601,4 @@ object Completion {

/** Both term and type symbols are allowed */
val ImportOrExport: Mode = new Mode(4) | Term | Type
}
}

35 changes: 24 additions & 11 deletions compiler/src/dotty/tools/repl/ReplCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ class ReplCompiler extends Compiler:
end compile

final def typeOf(expr: String)(using state: State): Result[String] =
typeCheck(expr).map { tree =>
typeCheck(expr).map { (_, tpdTree) =>
given Context = state.context
tree.rhs match {
tpdTree.rhs match {
case Block(xs, _) => xs.last.tpe.widen.show
case _ =>
"""Couldn't compute the type of your expression, so sorry :(
Expand Down Expand Up @@ -129,7 +129,7 @@ class ReplCompiler extends Compiler:
Iterator(sym) ++ sym.allOverriddenSymbols
}

typeCheck(expr).map {
typeCheck(expr).map { (_, tpdTree) => tpdTree match
case ValDef(_, _, Block(stats, _)) if stats.nonEmpty =>
val stat = stats.last.asInstanceOf[tpd.Tree]
if (stat.tpe.isError) stat.tpe.show
Expand All @@ -152,7 +152,7 @@ class ReplCompiler extends Compiler:
}
}

final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[(untpd.ValDef, tpd.ValDef)] = {

def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
Expand Down Expand Up @@ -181,22 +181,32 @@ class ReplCompiler extends Compiler:
}
}

def unwrapped(tree: tpd.Tree, sourceFile: SourceFile)(using Context): Result[tpd.ValDef] = {
def error: Result[tpd.ValDef] =
List(new Diagnostic.Error(s"Invalid scala expression",
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors
def error[Tree <: untpd.Tree](sourceFile: SourceFile): Result[Tree] =
List(new Diagnostic.Error(s"Invalid scala expression",
sourceFile.atSpan(Span(0, sourceFile.content.length)))).errors

def unwrappedTypeTree(tree: tpd.Tree, sourceFile0: SourceFile)(using Context): Result[tpd.ValDef] = {
import tpd._
tree match {
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
.getOrElse(error)
.getOrElse(error[tpd.ValDef](sourceFile0))
case _ =>
error
error[tpd.ValDef](sourceFile0)
}
}

def unwrappedUntypedTree(tree: untpd.Tree, sourceFile0: SourceFile)(using Context): Result[untpd.ValDef] =
import untpd._
tree match {
case PackageDef(_, List(TypeDef(_, tmpl: Template))) =>
tmpl.body
.collectFirst { case dd: ValDef if dd.name.show == "expr" => dd.result }
.getOrElse(error[untpd.ValDef](sourceFile0))
case _ =>
error[untpd.ValDef](sourceFile0)
}

val src = SourceFile.virtual("<typecheck>", expr)
inContext(state.context.fresh
Expand All @@ -209,7 +219,10 @@ class ReplCompiler extends Compiler:
ctx.run.nn.compileUnits(unit :: Nil, ctx)

if (errorsAllowed || !ctx.reporter.hasErrors)
unwrapped(unit.tpdTree, src)
for
tpdTree <- unwrappedTypeTree(unit.tpdTree, src)
untpdTree <- unwrappedUntypedTree(unit.untpdTree, src)
yield untpdTree -> tpdTree
else
ctx.reporter.removeBufferedMessages.errors
}
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,11 @@ class ReplDriver(settings: Array[String],
given state: State = newRun(state0)
compiler
.typeCheck(expr, errorsAllowed = true)
.map { tree =>
.map { (untpdTree, tpdTree) =>
val file = SourceFile.virtual("<completions>", expr, maybeIncomplete = true)
val unit = CompilationUnit(file)(using state.context)
unit.tpdTree = tree
unit.untpdTree = untpdTree
unit.tpdTree = tpdTree
given Context = state.context.fresh.setCompilationUnit(unit)
val srcPos = SourcePosition(file, Span(cursor))
val completions = try Completion.completions(srcPos)._2 catch case NonFatal(_) => Nil
Expand Down
Loading

0 comments on commit de4ad2b

Please sign in to comment.