Skip to content

Commit

Permalink
Simplify enum derivation in Scala 3 (#453)
Browse files Browse the repository at this point in the history
This PR allows a simple `derives` call to work on Scala 3 enums.

We add a new `macroRWAll` macro, that `derives` now forwards to. This
can be called once on each sealed trait or enum and will automatically
generate the `macroR` calls on all the sub-classes. We then add a pair
of `superTypeReader`/`superTypeWriter` forwarders to allow sub-classes
to make use of super-class readers and writers automatically without any
input from the user.

Note that this approach does not work in Scala 2; while `macroRWAll` can
be defined, `superTypeReader`/`superTypeWriter` seem to cause diverging
implicit expansions. So for now I leave the Scala 2 experience unchanged
  • Loading branch information
lihaoyi committed Mar 11, 2023
1 parent 2dcd043 commit 26cd155
Show file tree
Hide file tree
Showing 13 changed files with 395 additions and 181 deletions.
13 changes: 9 additions & 4 deletions bench/src/Common.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object Common{
import Generic.ADT
import Hierarchy._
import Recursive._
type Data = ADT[Seq[(Boolean, String, Int, Double)], String, A, LL, ADTc, ADT0]
type Data = ADT[Seq[(Boolean, String, Int, Double)], String, A, LL, Seq[ADTc], ADT0]
val benchmarkSampleData: Seq[Data] = Seq.fill(1000)(ADT(
Vector(
(true, "zero", 0, 0),
Expand Down Expand Up @@ -53,9 +53,14 @@ object Common{
)
)
),
ADTc(
i = 1234567890,
s = "I am cow hear me moo I weigh twice as much as you and I look good on the barbecue"
Seq(
ADTc(i = 1234567890, s = "I am cow hear me moo"),
ADTc(i = 1234567890, s = "I weigh twice as much as you"),
ADTc(i = 1234567890, s = "And I look good on the barbecue"),
ADTc(i = 1234567890, s = "Yoghurt curds cream cheese and butter"),
ADTc(i = 1234567890, s = "Come from liquids from my udder"),
ADTc(i = 1234567890, s = "I can cow I am cow"),
ADTc(i = 1234567890, s = "Hear me moo moo."),
),
ADT0()
))
Expand Down
3 changes: 2 additions & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import com.github.lolgab.mill.mima._

val scala212 = "2.12.17"
val scala213 = "2.13.10"
val scala3 = "3.1.3"

val scala3 = "3.2.2"
val scalaJS = "1.13.0"
val scalaNative = "0.4.10"
val acyclic = "0.3.6"
Expand Down
11 changes: 5 additions & 6 deletions implicits/src-3/upickle/implicits/MacroImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ trait MacroImplicits extends Readers with Writers with upickle.core.Annotator:
this: upickle.core.Types =>

inline def macroRW[T: ClassTag](using Mirror.Of[T]): ReadWriter[T] =
ReadWriter.join(
macroR[T],
macroW[T]
)
end macroRW
ReadWriter.join(macroR[T], macroW[T])

inline def macroRWAll[T: ClassTag](using Mirror.Of[T]): ReadWriter[T] =
ReadWriter.join(macroRAll[T], macroWAll[T])


// Usually, we would use an extension method to add `derived` to ReadWriter's
Expand Down Expand Up @@ -47,7 +46,7 @@ trait MacroImplicits extends Readers with Writers with upickle.core.Annotator:
// Until that is the case, we'll have to resort to using Scala 2's implicit
// classes to emulate extension methods for deriving readers and writers.
implicit class ReadWriterExtension(r: ReadWriter.type):
inline def derived[T](using Mirror.Of[T], ClassTag[T]): ReadWriter[T] = macroRW[T]
inline def derived[T](using Mirror.Of[T], ClassTag[T]): ReadWriter[T] = macroRWAll[T]
end ReadWriterExtension

end MacroImplicits
15 changes: 13 additions & 2 deletions implicits/src-3/upickle/implicits/Readers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,28 @@ trait ReadersVersionSpecific extends MacrosCommon:
else reader

case m: Mirror.SumOf[T] =>

val readers: List[Reader[_ <: T]] = compiletime.summonAll[Tuple.Map[m.MirroredElemTypes, Reader]]
.toList
.asInstanceOf[List[Reader[_ <: T]]]

Reader.merge[T](readers: _*)
}

inline given[T <: Singleton : Mirror.Of]: Reader[T] = macroR[T]
inline def macroRAll[T](using m: Mirror.Of[T]): Reader[T] = inline m match {
case m: Mirror.ProductOf[T] => macroR[T]
case m: Mirror.SumOf[T] =>
macros.defineEnumReaders[Reader[T], Tuple.Map[m.MirroredElemTypes, Reader]](this)
}

inline given superTypeReader[T: Mirror.ProductOf, V >: T : Reader]: Reader[T] = {
val actual = implicitly[Reader[V]].asInstanceOf[TaggedReader[T]]
val tagName = macros.tagName[T]
new TaggedReader.Leaf(tagName, actual.findReader(tagName))
}

// see comment in MacroImplicits as to why Dotty's extension methods aren't used here
implicit class ReaderExtension(r: Reader.type):
inline def derived[T](using Mirror.Of[T]): Reader[T] = macroR[T]
inline def derived[T](using Mirror.Of[T]): Reader[T] = macroRAll[T]
end ReaderExtension
end ReadersVersionSpecific
12 changes: 10 additions & 2 deletions implicits/src-3/upickle/implicits/Writers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,19 @@ trait WritersVersionSpecific extends MacrosCommon:
Writer.merge[T](writers: _*): Writer[T]
}

inline given[T <: Singleton : Mirror.Of : ClassTag]: Writer[T] = macroW[T]
inline def macroWAll[T: ClassTag](using m: Mirror.Of[T]): Writer[T] = inline m match{
case m: Mirror.ProductOf[T] => macroW[T]
case m: Mirror.SumOf[T] =>
macros.defineEnumWriters[Writer[T], Tuple.Map[m.MirroredElemTypes, Writer]](this)
}

inline given superTypeWriter[T: Mirror.ProductOf : ClassTag, V >: T : Writer]: Writer[T] = {
implicitly[Writer[V]].comap[T](_.asInstanceOf[V])
}

// see comment in MacroImplicits as to why Dotty's extension methods aren't used here
implicit class WriterExtension(r: Writer.type):
inline def derived[T](using Mirror.Of[T], ClassTag[T]): Writer[T] = macroW[T]
inline def derived[T](using Mirror.Of[T], ClassTag[T]): Writer[T] = macroWAll[T]
end WriterExtension

end WritersVersionSpecific
83 changes: 70 additions & 13 deletions implicits/src-3/upickle/implicits/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package upickle.implicits.macros

import scala.quoted.{ given, _ }
import deriving._, compiletime._

import upickle.implicits.ReadersVersionSpecific
type IsInt[A <: Int] = A

def getDefaultParamsImpl0[T](using Quotes, Type[T]): Map[String, Expr[AnyRef]] =
Expand Down Expand Up @@ -172,18 +172,29 @@ def tagNameImpl[T](using Quotes, Type[T]): Expr[String] =

val sym = TypeTree.of[T].symbol

extractKey(sym) match
case Some(name) => Expr(name)
case None =>
// In Scala 3 enums, we use the short name of each case as the tag, rather
// than the fully-qualified name. We can do this because we know that all
// enum cases are in the same `enum Foo` namespace with distinct short names,
// whereas sealed trait instances could be all over the place with identical
// short names only distinguishable by their prefix.
//
// Harmonizing these two cases further is TBD
if (sym.flags.is(Flags.Enum)) Expr(sym.name.filter(_ != '$'))
else Expr(TypeTree.of[T].tpe.typeSymbol.fullName.filter(_ != '$'))
Expr(
extractKey(sym) match
case Some(name) => name
case None =>
// In Scala 3 enums, we use the short name of each case as the tag, rather
// than the fully-qualified name. We can do this because we know that all
// enum cases are in the same `enum Foo` namespace with distinct short names,
// whereas sealed trait instances could be all over the place with identical
// short names only distinguishable by their prefix.
//
// Harmonizing these two cases further is TBD
if (TypeRepr.of[T] <:< TypeRepr.of[scala.reflect.Enum]) {
// Sometimes .symbol/.typeSymbol gives the wrong thing:
//
// - `.symbol.name` returns `<none>` for `LinkedList.Node[T]`
// - `.typeSymbol` returns `LinkedList` for `LinkedList.End`
//
// so we just mangle `.show` even though it's super gross
TypeRepr.of[T].show.split('.').last.takeWhile(_ != '[')
} else {
TypeTree.of[T].tpe.typeSymbol.fullName.filter(_ != '$')
}
)

inline def isSingleton[T]: Boolean = ${ isSingletonImpl[T] }
def isSingletonImpl[T](using Quotes, Type[T]): Expr[Boolean] =
Expand All @@ -198,3 +209,49 @@ def getSingletonImpl[T](using Quotes, Type[T]): Expr[T] =
case tref: TypeRef => Ref(tref.classSymbol.get.companionModule).asExpr.asInstanceOf[Expr[T]]
case v => '{valueOf[T]}
}


inline def defineEnumReaders[T0, T <: Tuple](prefix: Any): T0 = ${ defineEnumVisitorsImpl[T0, T]('prefix, "macroR") }
inline def defineEnumWriters[T0, T <: Tuple](prefix: Any): T0 = ${ defineEnumVisitorsImpl[T0, T]('prefix, "macroW") }
def defineEnumVisitorsImpl[T0, T <: Tuple](prefix: Expr[Any], macroX: String)(using Quotes, Type[T0], Type[T]): Expr[T0] =
import quotes.reflect._

def handleType(tpe: TypeRepr, name: String, skipTrait: Boolean): Option[(ValDef, Symbol)] = {

val AppliedType(typePrefix, List(arg)) = tpe

if (skipTrait && arg.typeSymbol.flags.is(Flags.Trait)) None
else {
val sym = Symbol.newVal(
Symbol.spliceOwner,
name,
tpe,
Flags.Implicit | Flags.Lazy,
Symbol.noSymbol
)

val macroCall = TypeApply(
Select(prefix.asTerm, prefix.asTerm.tpe.typeSymbol.memberMethod(macroX).head),
List(TypeTree.of(using arg.asType))
)

val newDef = ValDef(sym, Some(macroCall))

Some((newDef, sym))
}
}

def getDefs(t: TypeRepr, defs: List[(ValDef, Symbol)]): List[(ValDef, Symbol)] = {
t match{
case AppliedType(prefix, args) =>
val defAndSymbol = handleType(args(0), "x" + defs.size, skipTrait = true)
getDefs(args(1), defAndSymbol.toList ::: defs)
case _ if t =:= TypeRepr.of[EmptyTuple] => defs
}
}
val subTypeDefs = getDefs(TypeRepr.of[T], Nil)
val topTraitDefs = handleType(TypeRepr.of[T0], "x" + subTypeDefs.size, skipTrait = false)
val allDefs = topTraitDefs.toList ::: subTypeDefs

Block(allDefs.map(_._1), Ident(allDefs.head._2.termRef)).asExprOf[T0]

2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@

sbt.version=1.2.8
sbt.version=1.8.2
2 changes: 1 addition & 1 deletion project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
addSbtPlugin("com.lihaoyi" % "scalatex-sbt-plugin" % "0.3.12")
addSbtPlugin("com.lihaoyi" % "scalatex-sbt-plugin" % "0.3.11")
82 changes: 0 additions & 82 deletions upickle/test/src-3/upickle/Derivation.scala

This file was deleted.

Loading

0 comments on commit 26cd155

Please sign in to comment.