Skip to content

Commit

Permalink
implementation and tests for unrolling non-first parameter list
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Feb 17, 2024
1 parent 9264ed6 commit 3cf5c5c
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 29 deletions.
1 change: 1 addition & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ trait UnrollModule extends Cross.Module[String]{
"primaryConstructor",
"secondaryConstructor",
"caseclass",
"secondParameterList",
// "abstractTraitMethod",
// "abstractClassMethod"
)
Expand Down
53 changes: 33 additions & 20 deletions unroll/plugin/src-2/UnrollPhaseScala2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,17 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
}
}

val forwarderMethodType = defdef.symbol.tpe match{
def forwarderMethodType0(t: Type, n: Int): Type = t match{
case MethodType(originalParams, result) =>
val forwarderParams = originalParams.map(symbolReplacements)
MethodType(forwarderParams.take(paramIndex), result)
if (n == annotatedParamListIndex) MethodType(forwarderParams.take(paramIndex), result)
else MethodType(forwarderParams, forwarderMethodType0(result, n + 1))

case PolyType(tparams, MethodType(originalParams, result)) =>
val forwarderParams = originalParams.map(symbolReplacements)
PolyType(tparams, MethodType(forwarderParams.take(paramIndex), result))
case PolyType(tparams, res) => PolyType(tparams, forwarderMethodType0(res, n))
}

val forwarderMethodType = forwarderMethodType0(defdef.symbol.tpe, 0)

forwarderDefSymbol.setInfo(forwarderMethodType)

val newParamLists = paramLists
Expand All @@ -80,37 +81,49 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
}
.map(_.map(copyValDef))

val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map{n =>
val mangledName = defdef.name.toString + "$default$" + (n + 1)

val defaultOwner =
if (defdef.symbol.isConstructor) implDef.symbol.companionModule
else implDef.symbol

val defaultMember = defaultOwner.tpe.member(TermName(scala.reflect.NameTransformer.encode(mangledName)))
Ident(mangledName).setSymbol(defaultMember).set(defaultMember)
}
val defaultOffset = paramLists
.iterator
.take(annotatedParamListIndex)
.map(_.size)
.sum

val forwardedValueParams = newParamLists(annotatedParamListIndex).map(p => Ident(p.name).set(p.symbol))

val forwarderThis = This(defdef.symbol.owner).set(defdef.symbol.owner)

val forwarderInner =
if (defdef.symbol.isConstructor) Super(forwarderThis, typeNames.EMPTY).set(defdef.symbol.owner)
else forwarderThis

val nestedForwarderMethodTypes = Seq
.iterate(defdef.symbol.tpe, defdef.vparamss.length + 1){
case MethodType(args, res) => res
case PolyType(tparams, MethodType(args, res)) => res
}
.drop(1)

val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map{n =>
val mangledName = defdef.name.toString + "$default$" + (defaultOffset + n + 1)

val defaultOwner =
if (defdef.symbol.isConstructor) implDef.symbol.companionModule
else implDef.symbol

val defaultMember = defaultOwner.tpe.member(TermName(scala.reflect.NameTransformer.encode(mangledName)))
newParamLists.take(annotatedParamListIndex).map(_.map( p => Ident(p.name).set(p.symbol)))
.zip(nestedForwarderMethodTypes)
.foldLeft(Ident(mangledName).setSymbol(defaultMember).set(defaultMember).set(defaultMember): Tree) {
case (lhs, (ps, methodType)) => Apply(fun = lhs, args = ps).setType(methodType)
}

}

val forwarderCallArgs = newParamLists.zipWithIndex.map{case (v, i) =>
if (i == annotatedParamListIndex) forwardedValueParams ++ defaultCalls
else v.map( p => Ident(p.name).set(p.symbol))
}

val forwarderThis = This(defdef.symbol.owner).set(defdef.symbol.owner)

val forwarderInner =
if (defdef.symbol.isConstructor) Super(forwarderThis, typeNames.EMPTY).set(defdef.symbol.owner)
else forwarderThis

val forwarderCall0 = forwarderCallArgs
.zip(nestedForwarderMethodTypes)
.foldLeft(Select(forwarderInner, defdef.name).set(defdef.symbol): Tree){
Expand Down
37 changes: 28 additions & 9 deletions unroll/plugin/src-3/UnrollPhaseScala3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class UnrollPhaseScala3() extends PluginPhase {
)
}

def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
def generateSingleForwarder(defdef: DefDef,
prevMethodType: Type,
paramIndex: Int,
Expand All @@ -48,14 +49,16 @@ class UnrollPhaseScala3() extends PluginPhase {
isCaseApply: Boolean)
(using Context) = {

def truncateMethodType0(tpe: Type): Type = {
def truncateMethodType0(tpe: Type, n: Int): Type = {
tpe match{
case pt: PolyType => PolyType(pt.paramNames, pt.paramInfos, truncateMethodType0(pt.resType))
case mt: MethodType => MethodType(mt.paramInfos.take(paramIndex), mt.resType)
case pt: PolyType => PolyType(pt.paramNames, pt.paramInfos, truncateMethodType0(pt.resType, n + 1))
case mt: MethodType =>
if (n == annotatedParamListIndex) MethodType(mt.paramInfos.take(paramIndex), mt.resType)
else MethodType(mt.paramInfos, truncateMethodType0(mt.resType, n + 1))
}
}

val truncatedMethodType = truncateMethodType0(prevMethodType)
val truncatedMethodType = truncateMethodType0(prevMethodType, 0)
val forwarderDefSymbol = Symbols.newSymbol(
defdef.symbol.owner,
defdef.name,
Expand All @@ -66,22 +69,38 @@ class UnrollPhaseScala3() extends PluginPhase {
val newParamLists: List[ParamClause] = paramLists.zipWithIndex.map{ case (ps, i) =>
if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
else {
if (ps.headOption.exists(_.isInstanceOf[TypeDef])) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol))
if (isTypeClause(ps)) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol))
else ps.map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
}
}

val defaultOffset = paramLists
.iterator
.take(annotatedParamListIndex)
.filter(!isTypeClause(_))
.map(_.size)
.sum

val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map(n =>
if (defdef.symbol.isConstructor) {
val inner = if (defdef.symbol.isConstructor) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.name, n))
.select(DefaultGetterName(defdef.name, n + defaultOffset))
} else if (isCaseApply) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(termName("<init>"), n))
.select(DefaultGetterName(termName("<init>"), n + defaultOffset))
} else {
This(defdef.symbol.owner.asClass)
.select(DefaultGetterName(defdef.name, n))
.select(DefaultGetterName(defdef.name, n + defaultOffset))
}

newParamLists
.take(annotatedParamListIndex)
.map(_.map(p => ref(p.symbol)))
.foldLeft[Tree](inner){
case (lhs: Tree, newParams) =>
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
else Apply(lhs, newParams)
}
)

val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(defdef.symbol)
Expand Down
5 changes: 5 additions & 0 deletions unroll/tests/secondParameterList/v1/src/Unrolled.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package unroll

class Unrolled{
def foo(f: String => String)(s: String) = f(s)
}
23 changes: 23 additions & 0 deletions unroll/tests/secondParameterList/v1/test/src/UnrollTestMain.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package unroll

import unroll.TestUtils.logAssertStartsWith

object UnrollTestMain{
def main(args: Array[String]): Unit = {
logAssertStartsWith(new Unrolled().foo(identity)("cow"), "cow")
}
}














7 changes: 7 additions & 0 deletions unroll/tests/secondParameterList/v2/src/Unrolled.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package unroll

import scala.annotation.unroll

class Unrolled{
def foo(f: String => String)(s: String, @unroll n: Int = 1, b: Boolean = true) = f(s + n + b)
}
25 changes: 25 additions & 0 deletions unroll/tests/secondParameterList/v2/test/src/UnrollTestMain.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package unroll

import unroll.TestUtils.logAssertStartsWith

object UnrollTestMain{
def main(args: Array[String]): Unit = {
logAssertStartsWith(new Unrolled().foo(identity)("cow"), "cow1true")
logAssertStartsWith(new Unrolled().foo(identity)("cow", 2), "cow2true")
logAssertStartsWith(new Unrolled().foo(identity)("cow", 2, false), "cow2false")
}
}














10 changes: 10 additions & 0 deletions unroll/tests/secondParameterList/v3/src/Unrolled.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package unroll

import scala.annotation.unroll

class Unrolled{
def foo(f: String => String)(s: String, @unroll n: Int = 1, b: Boolean = true, @unroll l: Long = 0) = f(s + n + b + l)
}



Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package unroll

object UnrollTestPlatformSpecific{
def apply() = {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package unroll

object UnrollTestPlatformSpecific{
def apply() = {
val instance = new Unrolled()
val cls = classOf[Unrolled]

assert(
cls.getMethod("foo", classOf[String => String], classOf[String])
.invoke(instance, identity[String](_), "hello") ==
"hello1true0"
)

assert(
scala.util.Try(cls.getMethod("foo", classOf[String => String], classOf[String], classOf[Int])).isFailure
)
assert(
cls.getMethod("foo", classOf[String => String], classOf[String], classOf[Int], classOf[Boolean])
.invoke(instance, identity[String](_), "hello", 2: Integer, java.lang.Boolean.FALSE) ==
"hello2false0"
)
assert(
cls.getMethod("foo", classOf[String => String], classOf[String], classOf[Int], classOf[Boolean], classOf[Long])
.invoke(instance, identity[String](_), "hello", 2: Integer, java.lang.Boolean.FALSE, 3: Integer) ==
"hello2false3"
)

cls.getMethods.filter(_.getName.contains("foo")).foreach(println)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package unroll

object UnrollTestPlatformSpecific{
def apply() = {}
}
28 changes: 28 additions & 0 deletions unroll/tests/secondParameterList/v3/test/src/UnrollTestMain.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package unroll

import unroll.TestUtils.logAssertStartsWith

object UnrollTestMain{
def main(args: Array[String]): Unit = {
UnrollTestPlatformSpecific()

logAssertStartsWith(new Unrolled().foo(identity)("cow"), "cow1true0")
logAssertStartsWith(new Unrolled().foo(identity)("cow", 2), "cow2true0")
logAssertStartsWith(new Unrolled().foo(identity)("cow", 2, false), "cow2false0")
logAssertStartsWith(new Unrolled().foo(identity)("cow", 2, false, 3), "cow2false3")
}
}














0 comments on commit 3cf5c5c

Please sign in to comment.