Skip to content

Commit

Permalink
sketch of non-first-parameter-list support in scala2
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Feb 17, 2024
1 parent 2a5cacc commit 9264ed6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
47 changes: 30 additions & 17 deletions unroll/plugin/src-2/UnrollPhaseScala2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
def generateSingleForwarder(implDef: ImplDef,
defdef: DefDef,
paramIndex: Int,
firstParamList: List[ValDef],
otherParamLists: List[List[ValDef]]) = {
annotatedParamListIndex: Int,
paramLists: List[List[ValDef]]) = {
val forwarderDefSymbol = defdef.symbol.owner.newMethod(defdef.name)
val symbolReplacements = defdef
.vparamss
Expand Down Expand Up @@ -72,12 +72,15 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT

forwarderDefSymbol.setInfo(forwarderMethodType)

val newVParamss =
List(firstParamList.take(paramIndex).map(copyValDef)) ++ otherParamLists.map(_.map(copyValDef))

val forwardedValueParams = firstParamList.take(paramIndex).map(p => Ident(p.name).set(p.symbol))
val newParamLists = paramLists
.zipWithIndex
.map{ case (paramList, i) =>
if (i != annotatedParamListIndex) paramList
else paramList.take(paramIndex)
}
.map(_.map(copyValDef))

val defaultCalls = Range(paramIndex, firstParamList.size).map{n =>
val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map{n =>
val mangledName = defdef.name.toString + "$default$" + (n + 1)

val defaultOwner =
Expand All @@ -88,6 +91,8 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
Ident(mangledName).setSymbol(defaultMember).set(defaultMember)
}

val forwardedValueParams = newParamLists(annotatedParamListIndex).map(p => Ident(p.name).set(p.symbol))

val forwarderThis = This(defdef.symbol.owner).set(defdef.symbol.owner)

val forwarderInner =
Expand All @@ -101,9 +106,10 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
}
.drop(1)

val forwarderCallArgs =
Seq(forwardedValueParams ++ defaultCalls) ++
newVParamss.tail.map(_.map( p => Ident(p.name).set(p.symbol)))
val forwarderCallArgs = newParamLists.zipWithIndex.map{case (v, i) =>
if (i == annotatedParamListIndex) forwardedValueParams ++ defaultCalls
else v.map( p => Ident(p.name).set(p.symbol))
}

val forwarderCall0 = forwarderCallArgs
.zip(nestedForwarderMethodTypes)
Expand All @@ -120,7 +126,7 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
mods = defdef.mods,
name = defdef.name,
tparams = defdef.tparams,
vparamss = newVParamss,
vparamss = newParamLists,
tpt = defdef.tpt,
rhs = forwarderCall
).set(forwarderDefSymbol)
Expand Down Expand Up @@ -155,19 +161,26 @@ class UnrollPhaseScala2(val global: Global) extends PluginComponent with TypingT
// do not have companion class primary constructor symbols, so we just skip them here
annotatedOpt.toList.flatMap{ annotated =>
try {
defdef.vparamss match {
annotated.tpe.paramss
.zipWithIndex
.flatMap{case (annotatedParamList, paramListIndex) =>
val annotationIndices = findUnrollAnnotations(annotatedParamList)
if (annotationIndices.isEmpty) None
else Some((annotatedParamList, annotationIndices, paramListIndex))
} match{
case Nil => Nil
case firstParamList :: otherParamLists =>
val annotations = findUnrollAnnotations(annotated.tpe.params)
for (paramIndex <- annotations) yield {
case Seq((annotatedParamList, annotationIndices, paramListIndex)) =>
for (paramIndex <- annotationIndices) yield {
generateSingleForwarder(
implDef,
defdef,
paramIndex,
firstParamList,
otherParamLists
paramListIndex,
defdef.vparamss
)
}

case multiple => sys.error("Multiple")
}
}catch{case e: Throwable =>
throw new Exception(
Expand Down
38 changes: 19 additions & 19 deletions unroll/plugin/src-3/UnrollPhaseScala3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ class UnrollPhaseScala3() extends PluginPhase {

def generateSingleForwarder(defdef: DefDef,
prevMethodType: Type,
paramLists: List[ParamClause],
firstValueParamClauseIndex: Int,
paramIndex: Int,
annotatedParamListIndex: Int,
paramLists: List[ParamClause],
isCaseApply: Boolean)
(using Context) = {

Expand All @@ -63,15 +63,15 @@ class UnrollPhaseScala3() extends PluginPhase {
truncatedMethodType
)

val updated: List[ParamClause] = paramLists.zipWithIndex.map{ case (ps, i) =>
if (i == firstValueParamClauseIndex) ps.take(paramIndex).map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
val newParamLists: List[ParamClause] = paramLists.zipWithIndex.map{ case (ps, i) =>
if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
else {
if (ps.headOption.exists(_.isInstanceOf[TypeDef])) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol))
else ps.map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
}
}

val defaultCalls = Range(paramIndex, paramLists(firstValueParamClauseIndex).size).map(n =>
val defaultCalls = Range(paramIndex, paramLists(annotatedParamListIndex).size).map(n =>
if (defdef.symbol.isConstructor) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.name, n))
Expand All @@ -84,15 +84,15 @@ class UnrollPhaseScala3() extends PluginPhase {
}
)

val allNewParamTrees =
updated.zipWithIndex.map{case (ps, i) =>
if (i == firstValueParamClauseIndex) ps.map(p => ref(p.symbol)) ++ defaultCalls
val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(defdef.symbol)

val forwarderCallArgs =
newParamLists.zipWithIndex.map{case (ps, i) =>
if (i == annotatedParamListIndex) ps.map(p => ref(p.symbol)) ++ defaultCalls
else ps.map(p => ref(p.symbol))
}

val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(defdef.symbol)

val forwarderCall0 = allNewParamTrees.foldLeft[Tree](forwarderInner){
val forwarderCall0 = forwarderCallArgs.foldLeft[Tree](forwarderInner){
case (lhs: Tree, newParams) =>
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
else Apply(lhs, newParams)
Expand All @@ -102,17 +102,17 @@ class UnrollPhaseScala3() extends PluginPhase {
if (!defdef.symbol.isConstructor) forwarderCall0
else Block(List(forwarderCall0), Literal(Constant(())))

val newDefDef = implicitly[Context].typeAssigner.assignType(
val forwarderDef = implicitly[Context].typeAssigner.assignType(
cpy.DefDef(defdef)(
name = forwarderDefSymbol.name,
paramss = updated,
paramss = newParamLists,
tpt = defdef.tpt,
rhs = forwarderCall
),
forwarderDefSymbol
)

newDefDef
forwarderDef
}

def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
Expand Down Expand Up @@ -165,11 +165,11 @@ class UnrollPhaseScala3() extends PluginPhase {
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol

val firstValueParamClauseIndex = annotated.paramSymss.indexWhere(!_.headOption.exists(_.isType))
val annotatedParamListIndex = annotated.paramSymss.indexWhere(!_.headOption.exists(_.isType))

if (firstValueParamClauseIndex == -1) (None, Nil)
if (annotatedParamListIndex == -1) (None, Nil)
else {
val paramCount = annotated.paramSymss(firstValueParamClauseIndex).size
val paramCount = annotated.paramSymss(annotatedParamListIndex).size
annotated
.paramSymss
.zipWithIndex
Expand Down Expand Up @@ -197,9 +197,9 @@ class UnrollPhaseScala3() extends PluginPhase {
generateSingleForwarder(
defdef,
defdef.symbol.info,
defdef.paramss,
paramClauseIndex,
paramIndex,
paramClauseIndex,
defdef.paramss,
isCaseApply
)
}
Expand Down

0 comments on commit 9264ed6

Please sign in to comment.