Skip to content

Commit

Permalink
more tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
lihaoyi committed Feb 11, 2024
1 parent ed92a7f commit 3b91875
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions unroll/plugin/src-3/UnrollPhaseScala3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,6 @@ class UnrollPhaseScala3() extends PluginPhase {
newDefDef
}

def isCaseFromProduct(t: Tree)(using Context) = t match{
case defdef: DefDef => defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass)
case _ => false
}

def generateFromProduct(startParamIndex: Int, paramCount: Int, defdef: DefDef)(using Context) = {
cpy.DefDef(defdef)(
name = defdef.name,
Expand Down Expand Up @@ -152,7 +147,7 @@ class UnrollPhaseScala3() extends PluginPhase {
).setDefTree
}

def generateSyntheticDefs(tree: Tree)(using Context): Seq[Tree] = tree match{
def generateSyntheticDefs(tree: Tree)(using Context): (Option[Symbol], Seq[Tree]) = tree match{
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

Expand All @@ -162,7 +157,7 @@ class UnrollPhaseScala3() extends PluginPhase {
val isCaseApply =
defdef.name.toString == "apply" && defdef.symbol.owner.companionClass.is(CaseClass)

val isCaseFromProduct = this.isCaseFromProduct(defdef)
val isCaseFromProduct = defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass)

val annotated =
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
Expand All @@ -172,44 +167,48 @@ class UnrollPhaseScala3() extends PluginPhase {

val firstValueParamClauseIndex = annotated.paramSymss.indexWhere(!_.headOption.exists(_.isType))

if (firstValueParamClauseIndex == -1) Nil
if (firstValueParamClauseIndex == -1) (None, Nil)
else {
val paramCount = annotated.paramSymss(firstValueParamClauseIndex).size
annotated
.paramSymss(firstValueParamClauseIndex)
.indexWhere(_.annotations.exists(_.symbol.fullName.toString == "unroll.Unroll")) match{
case -1 => Nil
case -1 => (None, Nil)
case startParamIndex =>
if (isCaseFromProduct) {
Seq(generateFromProduct(startParamIndex, paramCount, defdef))
(Some(defdef.symbol), Seq(generateFromProduct(startParamIndex, paramCount, defdef)))
} else {
for (paramIndex <- Range(startParamIndex, paramCount)) yield {
generateSingleForwarder(
defdef,
defdef.symbol.info,
defdef.paramss,
firstValueParamClauseIndex,
paramIndex,
isCaseApply
)
}
(
None,
for (paramIndex <- Range(startParamIndex, paramCount)) yield {
generateSingleForwarder(
defdef,
defdef.symbol.info,
defdef.paramss,
firstValueParamClauseIndex,
paramIndex,
isCaseApply
)
}
)
}
}
}
case _ => Nil
case _ => (None, Nil)
}

override def transformTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = {

val (removed0, generatedDefs) = tmpl.body.map(generateSyntheticDefs).unzip
val (None, generatedConstr) = generateSyntheticDefs(tmpl.constr)
val removed = removed0.flatten
super.transformTemplate(
cpy.Template(tmpl)(
tmpl.constr,
tmpl.parents,
tmpl.derived,
tmpl.self,
tmpl.body.filter(!this.isCaseFromProduct(_)) ++
tmpl.body.flatMap(generateSyntheticDefs) ++
generateSyntheticDefs(tmpl.constr)
tmpl.body.filter(t => !removed.contains(t.symbol)) ++ generatedDefs.flatten ++ generatedConstr
)
)
}
Expand Down

0 comments on commit 3b91875

Please sign in to comment.