Skip to content
This repository has been archived by the owner. It is now read-only.

Commit

Permalink
add sealed traits unmarshallers derivation (oneof support)
Browse files Browse the repository at this point in the history
  • Loading branch information
fomkin committed Feb 11, 2019
1 parent 097651e commit 25f2e54
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 78 deletions.
12 changes: 8 additions & 4 deletions core/src/main/scala/zhukov/Unmarshaller.scala
Expand Up @@ -2,7 +2,7 @@ package zhukov

import zhukov.protobuf.CodedInputStream

sealed trait Unmarshaller[A] {
trait Unmarshaller[A] {

def read(stream: CodedInputStream): A

Expand All @@ -15,8 +15,12 @@ sealed trait Unmarshaller[A] {

object Unmarshaller {

trait LengthDelimitedUnmarshaller[A] extends Unmarshaller[A] { self =>
def map[B](f: A => B): LengthDelimitedUnmarshaller[B] =
/**
* For length-delimited unmarshaller with
* length-reading defined inside CodedInputStream.
*/
trait CodedUnmarshaller[A] extends Unmarshaller[A] { self =>
def map[B](f: A => B): CodedUnmarshaller[B] =
(stream: CodedInputStream) => f(self.read(stream))
}

Expand All @@ -40,5 +44,5 @@ object Unmarshaller {

implicit val int: VarintUnmarshaller[Int] = _.readRawVarint32()
implicit val long: VarintUnmarshaller[Long] = _.readRawVarint64()
implicit val string: LengthDelimitedUnmarshaller[String] = _.readString()
implicit val string: CodedUnmarshaller[String] = _.readString()
}
223 changes: 150 additions & 73 deletions derivation/src/main/scala/zhukov/derivation/ZhukovDerivationMacro.scala
Expand Up @@ -11,107 +11,184 @@ class ZhukovDerivationMacro(val c: blackbox.Context) {

import c.universe._

private val applyName = TermName("apply")

private final case class Field(
index: Int,
originalName: TermName,
varName: TermName,
defaultValueName: TermName,
tpe: Type,
repTpe: Option[Type]
)

private def checkIsClass(x: Symbol) = {
if (x.isClass) x.asClass
else c.abort(c.enclosingPosition, "Zhukov derivation supported only for case classes and sealed traits")
def unmarshallerImpl[T: WeakTypeTag]: Tree = {
val T = weakTypeTag[T].tpe
val ts = T.typeSymbol
if (ts.isClass && ts.asClass.isCaseClass) {
val companion = ts.companion
caseClassUnmarshaller(T, companion)
} else if (ts.isClass && ts.asClass.isTrait && ts.asClass.isSealed) {
sealedTraitUnmarshaller(T, ts.asClass)
}
else c.abort(c.enclosingPosition, "Zhukov derivation is supported only for case classes and sealed traits")
}

private def getOrAbort[T](x: Option[T], pos: Position, msg: String) = x match {
case None => c.abort(pos, msg)
case Some(value) => value
}

def unmarshallerImpl[T: WeakTypeTag]: Tree = {
val T = weakTypeTag[T].tpe
val targetClass = checkIsClass(T.typeSymbol)
val companion = targetClass.companion
val constructor = companion.typeSignature.decl(applyName).asMethod
val params = getOrAbort(constructor.paramLists.headOption, constructor.pos, "Case class should have parameters")
val fields = params.zipWithIndex.map {
case (param, i) =>
val defaultValue = TermName(s"apply$$default$$${i + 1}") // apply$default$1
val tpe = param.typeSignature
if (!companion.typeSignature.decls.exists(_.name == defaultValue))
c.abort(param.pos, "Parameter should have default value")
Field(
index = i + 1,
originalName = param.name.toTermName,
varName = TermName("_" + param.name.toString),
defaultValueName = defaultValue,
tpe = tpe,
repTpe = if (tpe <:< typeOf[Iterable[_]]) Some(tpe.typeArgs.head) else None
)
}
private def inferWireType(tpe: c.universe.Type) = {
if (c.typecheck(q"implicitly[zhukov.Unmarshaller.VarintUnmarshaller[$tpe]]", silent = true).tpe != NoType) VarInt
else if (c.typecheck(q"implicitly[zhukov.Unmarshaller.Fixed32Unmarshaller[$tpe]]", silent = true).tpe != NoType) Fixed32
else if (c.typecheck(q"implicitly[zhukov.Unmarshaller.Fixed64Unmarshaller[$tpe]]", silent = true).tpe != NoType) Fixed64
else if (c.typecheck(q"implicitly[zhukov.Unmarshaller.CodedUnmarshaller[$tpe]]", silent = true).tpe != NoType) Coded
else LengthDelimited
}

val vars = fields.map {
case Field(_, _, name, default, repTpe, Some(tpe)) =>
q"var $name = ${repTpe.typeSymbol.companion}.newBuilder[$tpe] ++= $companion.$default"
case Field(_, _, name, default, _, None) =>
q"var $name = $companion.$default"
}
val applyArgs = fields.map { x =>
x.repTpe match {
case Some(_) => q"${x.originalName} = ${x.varName}.result()"
case None => q"${x.originalName} = ${x.varName}"
}
private def commonUnmarshaller(T: Type, fields: List[Field]): Tree = {
val vars = fields.groupBy(_.varName).mapValues(_.head).collect {
case (name, Field(_, _, _, Some(default), repTpe, Some(tpe), None)) =>
q"var $name = ${repTpe.typeSymbol.companion}.newBuilder[$tpe] ++= $default"
case (name, Field(_, _, _, Some(default), _, None, None)) =>
q"var $name = $default"
case (name, Field(_, _, _, None, _, None, Some(parent))) =>
q"var $name:$parent = null"
}
val cases = fields.flatMap { x =>
val tpe = x.repTpe.getOrElse(x.tpe)
val wireType = {
if (c.typecheck(q"implicitly[zhukov.Unmarshaller.VarintUnmarshaller[$tpe]]", silent = true).tpe != NoType) WireFormat.WIRETYPE_VARINT
else if (c.typecheck(q"implicitly[zhukov.Unmarshaller.Fixed32Unmarshaller[$tpe]]", silent = true).tpe != NoType) WireFormat.WIRETYPE_FIXED32
else if (c.typecheck(q"implicitly[zhukov.Unmarshaller.Fixed64Unmarshaller[$tpe]]", silent = true).tpe != NoType) WireFormat.WIRETYPE_FIXED64
else WireFormat.WIRETYPE_LENGTH_DELIMITED
}
val tag = WireFormat.makeTag(x.index, wireType)
val wireType = inferWireType(tpe)
val tag = WireFormat.makeTag(x.index, wireType.value)
val singleRead = q"implicitly[zhukov.Unmarshaller[$tpe]].read(_stream)"
if (x.repTpe.isEmpty) {
List(cq"$tag => ${x.varName} = $singleRead")
wireType match {
case LengthDelimited =>
List(
cq"""$tag =>
val _length = _stream.readRawVarint32()
val _oldLimit = _stream.pushLimit(_length)
${x.varName} = $singleRead
_stream.checkLastTagWas(0)
_stream.popLimit(_oldLimit)
""")
case _ =>
List(cq"$tag => ${x.varName} = $singleRead")
}
} else {
val cases = List(cq"$tag => ${x.varName} += $singleRead")
val packed = wireType != WireFormat.WIRETYPE_LENGTH_DELIMITED
if (!packed) cases else {
val repTag = WireFormat.makeTag(x.index, WireFormat.WIRETYPE_LENGTH_DELIMITED)
val `case` = cq"""
wireType match {
case VarInt | Fixed32 | Fixed64 => // Packed
val repTag = WireFormat.makeTag(x.index, WireFormat.WIRETYPE_LENGTH_DELIMITED)
val `case` =
cq"""
$repTag =>
val _length = _stream.readRawVarint32()
val _oldLimit = _stream.pushLimit(_length)
while (_stream.getBytesUntilLimit > 0)
${x.varName} += $singleRead
_stream.popLimit(_oldLimit)
"""
`case` :: cases
List(`case`, cq"$tag => ${x.varName} += $singleRead")
case Coded =>
List(cq"$tag => ${x.varName} += $singleRead")
case LengthDelimited =>
List(
cq"""$tag =>
val _length = _stream.readRawVarint32()
val _oldLimit = _stream.pushLimit(_length)
${x.varName} += $singleRead
_stream.checkLastTagWas(0)
_stream.popLimit(_oldLimit)
""")
}
}
}

q"""
new zhukov.Unmarshaller.LengthDelimitedUnmarshaller[$T] {
var _done = false
..$vars
while (!_done) {
val _tag = _stream.readTag()
(_tag: @scala.annotation.switch) match {
case 0 => _done = true
case ..$cases
case _ => _stream.skipField(_tag)
}
}
"""
}

private def sealedTraitUnmarshaller(T: Type, ts: ClassSymbol): Tree = {
val children = ts.knownDirectSubclasses.toList
val termName = TermName("_value")
val fields = children.zipWithIndex.map {
case (x, i) =>
Field(
index = i + 1,
originalName = None,
varName = termName,
default = None,
tpe = x.asClass.toType,
repTpe = None,
parentType = Some(T)
)
}
q"""
new zhukov.Unmarshaller[$T] {
def read(_stream: zhukov.protobuf.CodedInputStream) = {
..${commonUnmarshaller(T, fields)}
_value
}
}
"""
}

private def caseClassUnmarshaller(T: Type, module: Symbol): Tree = {
val constructor = module.typeSignature.decl(applyName).asMethod
val params = getOrAbort(constructor.paramLists.headOption, constructor.pos, "Case class should have parameters")
val fields = params.zipWithIndex.map {
case (param, i) =>
val defaultValue = TermName(s"apply$$default$$${i + 1}") // apply$default$1
val tpe = param.typeSignature
if (!module.typeSignature.decls.exists(_.name == defaultValue))
c.abort(param.pos, "Parameter should have default value")
Field(
index = i + 1,
originalName = Some(param.name.toTermName),
varName = TermName("_" + param.name.toString),
default = Some(q"$module.$defaultValue"),
tpe = tpe,
repTpe =
if (tpe <:< typeOf[Iterable[_]]) Some(tpe.typeArgs.head)
else None,
parentType = None
)
}
val applyArgs = fields.collect {
case Field(_, Some(originalName), varName, _, _, Some(_), _) =>
q"$originalName = $varName.result()"
case Field(_, Some(originalName), varName, _, _, None, _) =>
q"$originalName = $varName"
}
q"""
new zhukov.Unmarshaller[$T] {
def read(_stream: zhukov.protobuf.CodedInputStream) = {
var _done = false
..$vars
while (!_done) {
val _tag = _stream.readTag()
(_tag: @scala.annotation.switch) match {
case 0 => _done = true
case ..$cases
case _ => _stream.skipField(_tag)
}
}
$companion.apply(..$applyArgs)
..${commonUnmarshaller(T, fields)}
$module.apply(..$applyArgs)
}
}
"""
"""
}

private val applyName = TermName("apply")

private sealed abstract class WireType(val value: Int)

private case object VarInt extends WireType(WireFormat.WIRETYPE_VARINT)

private case object Fixed64 extends WireType(WireFormat.WIRETYPE_FIXED64)

private case object Fixed32 extends WireType(WireFormat.WIRETYPE_FIXED32)

private case object Coded extends WireType(WireFormat.WIRETYPE_LENGTH_DELIMITED)

private case object LengthDelimited extends WireType(WireFormat.WIRETYPE_LENGTH_DELIMITED)

private final case class Field(index: Int,
originalName: Option[TermName],
varName: TermName,
default: Option[Tree],
tpe: Type,
repTpe: Option[Type],
parentType: Option[Type])

}
24 changes: 23 additions & 1 deletion derivation/src/test/scala/CompareWithScalapbTest.scala
@@ -1,6 +1,6 @@
import utest._
import zhukov.Unmarshaller
import zhukov.messages.{MessageWithRepeatedString, MessageWithSeq, SimpleMessage}
import zhukov.messages._
import zhukov.derivation.unmarshaller

object CompareWithScalapbTest extends TestSuite {
Expand All @@ -27,6 +27,21 @@ object CompareWithScalapbTest extends TestSuite {
unmarshaller[MessageWithRepeatedString2]
}

sealed trait Expr2

object Expr2 {
//case object Dummy extends Expr2
case class Lit2(value: Int = 0) extends Expr2
case class Add2(lhs: Expr2 = Lit2(), rhs: Expr2 = Lit2()) extends Expr2

implicit val u1: Unmarshaller[Lit2] =
unmarshaller[Lit2]
implicit val u2: Unmarshaller[Add2] =
unmarshaller[Add2]
implicit val u3: Unmarshaller[Expr2] =
unmarshaller[Expr2]
}

val tests = Tests {
"Read messages, serialized with ScalaPB, with zhukov.Unmarshaller" - {
'SimpleMessage - {
Expand Down Expand Up @@ -77,6 +92,13 @@ object CompareWithScalapbTest extends TestSuite {
message.myStrings == message2.myStrings
)
}
"Expr" - {
val l = Expr(Expr.Value.Lit(Lit(2)))
val r = Expr(Expr.Value.Lit(Lit(4)))
val message = Expr(Expr.Value.Add(Add(Some(l), Some(r))))
val message2 = Unmarshaller[Expr2].read(message.toByteArray)
assert(message2 == Expr2.Add2(Expr2.Lit2(2), Expr2.Lit2(4)))
}
}
}
}

0 comments on commit 25f2e54

Please sign in to comment.