diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 44fa4e9b22fd..7569e16e663f 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1142,6 +1142,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { buf.toList } + def collectSubTrees[A](f: PartialFunction[Tree, A])(using Context): List[A] = + val buf = mutable.ListBuffer[A]() + foreachSubTree(f.runWith(buf += _)(_)) + buf.toList + /** Set this tree as the `defTree` of its symbol and return this tree */ def setDefTree(using Context): ThisTree = { val sym = tree.symbol diff --git a/compiler/src/dotty/tools/dotc/core/Symbols.scala b/compiler/src/dotty/tools/dotc/core/Symbols.scala index b45702a511bd..000a2b5fcda8 100644 --- a/compiler/src/dotty/tools/dotc/core/Symbols.scala +++ b/compiler/src/dotty/tools/dotc/core/Symbols.scala @@ -19,9 +19,7 @@ import DenotTransformers.* import StdNames.* import NameOps.* import NameKinds.LazyImplicitName -import ast.tpd -import tpd.{Tree, TreeProvider, TreeOps} -import ast.TreeTypeMap +import ast.*, tpd.* import Constants.Constant import Variances.Variance import reporting.Message @@ -336,6 +334,15 @@ object Symbols extends SymUtils { denot.info.dropAlias.finalResultType.typeConstructor match case tp: NamedType => tp.symbol.sourceSymbol case _ => this + else if denot.is(ExportedTerm) then + val root = denot.maybeOwner match + case cls: ClassSymbol => cls.rootTreeContaining(name.toString) + case _ => EmptyTree + val targets = root.collectSubTrees: + case tree: DefDef if tree.name == name => methPart(tree.rhs).tpe + targets.match + case (tp: NamedType) :: _ => tp.symbol.sourceSymbol + case _ => this else if (denot.is(Synthetic)) { val linked = denot.linkedClass if (linked.exists && !linked.is(Synthetic)) diff --git a/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala b/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala index c213acef5fe8..2c2897e401a1 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/MetalsInteractive.scala @@ -1,19 +1,14 @@ -package dotty.tools.pc +package dotty.tools +package pc import scala.annotation.tailrec -import dotty.tools.dotc.ast.tpd -import dotty.tools.dotc.ast.tpd.* -import dotty.tools.dotc.ast.untpd -import dotty.tools.dotc.core.Contexts.* -import dotty.tools.dotc.core.Flags.* -import dotty.tools.dotc.core.Names.Name -import dotty.tools.dotc.core.StdNames -import dotty.tools.dotc.core.Symbols.* -import dotty.tools.dotc.core.Types.Type -import dotty.tools.dotc.interactive.SourceTree -import dotty.tools.dotc.util.SourceFile -import dotty.tools.dotc.util.SourcePosition +import dotc.* +import ast.*, tpd.* +import core.*, Contexts.*, Decorators.*, Flags.*, Names.*, Symbols.*, Types.* +import interactive.* +import util.* +import util.SourcePosition object MetalsInteractive: @@ -205,7 +200,7 @@ object MetalsInteractive: Nil case path @ head :: tail => - if head.symbol.is(ExportedType) then + if head.symbol.is(Exported) then val sym = head.symbol.sourceSymbol List((sym, sym.info)) else if head.symbol.is(Synthetic) then diff --git a/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala index f763702cffdf..c05545124495 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/definition/PcDefinitionSuite.scala @@ -199,89 +199,67 @@ class PcDefinitionSuite extends BasePcDefinitionSuite: |""".stripMargin ) - @Test def `exportType1` = - check( - """object enumerations: - | trait <> - | trait CymbalKind - | - |object all: - | export enumerations.* - | - |@main def hello = - | import all.SymbolKind - | import enumerations.CymbalKind - | - | val x = new Symbo@@lKind {} - | val y = new CymbalKind {} + @Test def exportType0 = + check( + """object Foo: + | trait <> + |object Bar: + | export Foo.* + |class Test: + | import Bar.* + | def test = new Ca@@t {} |""".stripMargin ) - @Test def `exportType1Wild` = - check( - """object enumerations: - | trait <> - | trait CymbalKind - | - |object all: - | export enumerations.SymbolKind - | - |@main def hello = - | import all.SymbolKind - | import enumerations.CymbalKind - | - | val x = new Symbo@@lKind {} - | val y = new CymbalKind {} + @Test def exportType1 = + check( + """object Foo: + | trait <>[A] + |object Bar: + | export Foo.* + |class Test: + | import Bar.* + | def test = new Ca@@t[Int] {} |""".stripMargin ) - @Test def `exportTerm1` = + @Test def exportTerm0Nullary = check( - """class BitMap - |class Scanner: - | def scan(): BitMap = ??? - |class Copier: - | private val scanUnit = new Scanner - | export scanUnit.<> - | def t1 = sc@@an() + """trait Foo: + | def <>: Int + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th |""".stripMargin ) - @Test def `exportTerm2` = + @Test def exportTerm0 = check( - """class BitMap - |class Scanner: - | def scan(): BitMap = ??? - |class Copier: - | private val scanUnit = new Scanner - | export scanUnit.<> - |class Test: - | def t2(cpy: Copier) = cpy.sc@@an() + """trait Foo: + | def <>(): Int + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th() |""".stripMargin ) - @Test def `exportTerm1Wild` = + @Test def exportTerm1 = check( - """class BitMap - |class Scanner: - | def scan(): BitMap = ??? - |class Copier: - | private val scanUnit = new Scanner - | export scanUnit.<<*>> - | def t1 = sc@@an() + """trait Foo: + | def <>(x: Int): Int + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th(0) |""".stripMargin ) - @Test def `exportTerm2Wild` = + @Test def exportTerm1Poly = check( - """class BitMap - |class Scanner: - | def scan(): BitMap = ??? - |class Copier: - | private val scanUnit = new Scanner - | export scanUnit.<<*>> - |class Test: - | def t2(cpy: Copier) = cpy.sc@@an() + """trait Foo: + | def <>[A](x: A): A + |class Bar(val foo: Foo): + | export foo.* + | def test(bar: Bar) = bar.me@@th(0) |""".stripMargin )