Skip to content

Commit

Permalink
incremental forwarders in preparation for abstract method unrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Feb 19, 2024
1 parent 3cf5c5c commit a339119
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 52 deletions.
66 changes: 66 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,69 @@ You can also run the following command to run all tests:
This can be useful as a final sanity check, even though you usually want to run
a subset of the tests specific to the `scala-version` and `test-case` you are
interested in.

```scala
trait Upstream{
def foo(s: String, n: Int = 1)
}
```
```scala
trait Upstream{
// source
def foo(s: String, n: Int = 1, @unroll b: Boolean = true)

// generated
def foo(s: String, n: Int = 1, b: Boolean = true) = foo(s, n)
def fooUp(s: String, n: Int, b: Boolean) = foo(s, n, b)
def fooDown(s: String, n: Int, b: Boolean) = foo(s, n, b)
def foo(s: String, n: Int) = foo(s, n, true)
}
```

```scala
trait Upstream{
// source
def foo(s: String, n: Int = 1, @unroll b: Boolean = true, @unroll l: Long = 0)

// generated
def foo(s: String, n: Int = 1, b: Boolean = true, l: Long = 0) = fooDown(s, n, b)
def fooUp(s: String, n: Int, b: Boolean) = foo(s, n, b, 0)
def fooDown(s: String, n: Int, b: Boolean) = foo(s, n)
def foo(s: String, n: Int) = fooUp(s, n, true)
}
```
```scala
trait Downstream extends Upstream{
final def foo(s: String, n: Int = 1) = println(s + n)
}
```
```scala
trait Downstream extends Upstream{
// source
def foo(s: String, n: Int = 1, b: Boolean = true) = println(s + n + b)

// generated
final def foo(s: String, n: Int = 1, b: Boolean = true) = foo(s, n)
final def fooUp(s: String, n: Int, b: Boolean) = foo(s, n, b)
final def fooDown(s: String, n: Int, b: Boolean) = foo(s, n, b)

final def foo(s: String, n: Int) = foo(s, n, true)
}
```

```scala
trait Downstream extends Upstream{
// source
def foo(s: String, n: Int = 1, b: Boolean = true, l: Long = 0)

// generated
final def foo(s: String, n: Int = 1, b: Boolean = true, l: Long = 0) = fooDown(s, n, b)
final def fooUp(s: String, n: Int, b: Boolean, l: Long) = fooDown(s, n, b)
final def fooDown(s: String, n: Int, b: Boolean, l: Long) = fooDown(s, n, b)

final def foo(s: String, n: Int, b: Boolean) = foo(s, n, b, 0)
final def fooUp(s: String, n: Int, b: Boolean) = foo(s, n, b, 0)
final def fooDown(s: String, n: Int, b: Boolean) = foo(s, n)
final def foo(s: String, n: Int) = fooUp(s, n, true)
}
```
5 changes: 3 additions & 2 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ trait UnrollModule extends Cross.Module[String]{
def moduleDeps = Seq(annotation)
def run(args: Task[Args] = T.task(Args())) = T.command{/*donothing*/}
def mimaPreviousArtifacts = T.traverse(mimaPrevious)(_.jvm.jar)()
override def scalacPluginClasspath = T{ Agg(plugin.jar()) }
def scalacPluginClasspath = T{ Agg(plugin.jar()) }

// override def scalaCompilerClasspath = T{
// super.scalaCompilerClasspath().filter(!_.toString().contains("scala3-compiler")) ++
// Agg(PathRef(os.Path("/Users/lihaoyi/.ivy2/local/org.scala-lang/scala3-compiler_3/3.3.2-RC3-bin-SNAPSHOT/jars/scala3-compiler_3.jar")))
// }
override def scalacOptions = T{
def scalacOptions = T{
Seq(
s"-Xplugin:${plugin.jar().path}",
"-Xplugin-require:unroll",
Expand All @@ -109,6 +109,7 @@ trait UnrollModule extends Cross.Module[String]{
def moduleDeps = Seq(Unrolled.this)
def mainClass = Some("unroll.UnrollTestMain")
def testFramework = T{ "" } // stub
def scalacOptions = Seq.empty[String]
}
}

Expand Down
31 changes: 19 additions & 12 deletions unroll/plugin/src-2/UnrollPhaseScala2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
def generateSingleForwarder(implDef: ImplDef,
defdef: DefDef,
paramIndex: Int,
nextParamIndex: Int,
nextSymbol: Symbol,
annotatedParamListIndex: Int,
paramLists: List[List[ValDef]]) = {

val forwarderDefSymbol = defdef.symbol.owner.newMethod(defdef.name)
val symbolReplacements = defdef
.vparamss
Expand Down Expand Up @@ -91,13 +94,13 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
val forwardedValueParams = newParamLists(annotatedParamListIndex).map(p => Ident(p.name).set(p.symbol))

val nestedForwarderMethodTypes = Seq
.iterate(defdef.symbol.tpe, defdef.vparamss.length + 1){
.iterate(nextSymbol.tpe, defdef.vparamss.length + 1){
case MethodType(args, res) => res
case PolyType(tparams, MethodType(args, res)) => res
}
.drop(1)

val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map{n =>
val defaultCalls = Range(paramIndex, nextParamIndex).map{n =>
val mangledName = defdef.name.toString + "$default$" + (defaultOffset + n + 1)

val defaultOwner =
Expand Down Expand Up @@ -126,7 +129,7 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT

val forwarderCall0 = forwarderCallArgs
.zip(nestedForwarderMethodTypes)
.foldLeft(Select(forwarderInner, defdef.name).set(defdef.symbol): Tree){
.foldLeft(Select(forwarderInner, defdef.name).set(nextSymbol): Tree){
case (lhs, (ps, methodType)) => Apply(fun = lhs, args = ps).setType(methodType)
}

Expand Down Expand Up @@ -183,15 +186,19 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
} match{
case Nil => Nil
case Seq((annotatedParamList, annotationIndices, paramListIndex)) =>
for (paramIndex <- annotationIndices) yield {
generateSingleForwarder(
implDef,
defdef,
paramIndex,
paramListIndex,
defdef.vparamss
)
}
(annotationIndices :+ annotatedParamList.length).sliding(2).toList.reverse.foldLeft((Seq.empty[DefDef], defdef.symbol)){
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarderDef = generateSingleForwarder(
implDef,
defdef,
paramIndex,
nextParamIndex,
nextSymbol,
paramListIndex,
defdef.vparamss
)
(forwarderDef +: defdefs, forwarderDef.symbol)
}._1

case multiple => sys.error("Multiple")
}
Expand Down
72 changes: 37 additions & 35 deletions unroll/plugin/src-3/UnrollPhaseScala3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,20 @@ class UnrollPhaseScala3() extends PluginPhase {
)
}

def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
params
.zipWithIndex
.collect {
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll") =>
i
}
}
def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
def generateSingleForwarder(defdef: DefDef,
prevMethodType: Type,
paramIndex: Int,
nextParamIndex: Int,
nextSymbol: Symbol,
annotatedParamListIndex: Int,
paramLists: List[ParamClause],
isCaseApply: Boolean)
Expand Down Expand Up @@ -81,7 +91,7 @@ class UnrollPhaseScala3() extends PluginPhase {
.map(_.size)
.sum

val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map(n =>
val defaultCalls = Range(paramIndex, nextParamIndex).map(n =>
val inner = if (defdef.symbol.isConstructor) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.name, n + defaultOffset))
Expand All @@ -103,7 +113,7 @@ class UnrollPhaseScala3() extends PluginPhase {
}
)

val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(defdef.symbol)
val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(nextSymbol)

val forwarderCallArgs =
newParamLists.zipWithIndex.map{case (ps, i) =>
Expand Down Expand Up @@ -184,49 +194,41 @@ class UnrollPhaseScala3() extends PluginPhase {
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol

val annotatedParamListIndex = annotated.paramSymss.indexWhere(!_.headOption.exists(_.isType))

if (annotatedParamListIndex == -1) (None, Nil)
else {
val paramCount = annotated.paramSymss(annotatedParamListIndex).size
annotated
.paramSymss
.zipWithIndex
.map{case (paramClause, paramClauseIndex) =>
annotated
.paramSymss
.zipWithIndex
.flatMap{case (paramClause, paramClauseIndex) =>
val annotationIndices = findUnrollAnnotations(paramClause)
if (annotationIndices.isEmpty) None
else Some((paramClauseIndex, annotationIndices))
} match{
case Nil => (None, Nil)
case Seq((paramClauseIndex, annotationIndices)) =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if (isCaseFromProduct) {
(Some(defdef.symbol), Seq(generateFromProduct(annotationIndices, paramCount, defdef)))
} else {
(
paramClauseIndex,
paramClause
.zipWithIndex
.collect {
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll") =>
i
}
)
}
.filter{case (paramClauseIndex, annotationIndices) => annotationIndices.nonEmpty } match{
case Nil => (None, Nil)
case Seq((paramClauseIndex, annotationIndices)) =>
if (isCaseFromProduct) {
(Some(defdef.symbol), Seq(generateFromProduct(annotationIndices, paramCount, defdef)))
} else {
(
None,

for (paramIndex <- annotationIndices) yield {
generateSingleForwarder(
None,
(annotationIndices :+ paramCount).sliding(2).toList.reverse.foldLeft((Seq.empty[DefDef], defdef.symbol)){
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
defdef,
defdef.symbol.info,
paramIndex,
nextParamIndex,
nextSymbol,
paramClauseIndex,
defdef.paramss,
isCaseApply
)
}
)
}
(forwarder +: defdefs, forwarder.symbol)
}._1
)
}

case multiple => sys.error("Cannot have multiple parameter lists containing `@unroll` annotation")
}
case multiple => sys.error("Cannot have multiple parameter lists containing `@unroll` annotation")
}

case _ => (None, Nil)
Expand Down
3 changes: 0 additions & 3 deletions unroll/tests/secondParameterList/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,3 @@ import scala.annotation.unroll
class Unrolled{
def foo(f: String => String)(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = f(s + n + b + l)
}



0 comments on commit a339119

Please sign in to comment.