Skip to content

Commit

Permalink
Add readers for some well-known types
Browse files Browse the repository at this point in the history
  • Loading branch information
jodersky committed Mar 9, 2023
1 parent 562c711 commit 49b2303
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 53 deletions.
148 changes: 97 additions & 51 deletions scalapb-ujson/src/scalapb/ujson/JsonFormat.scala
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
package scalapb.ujson

import upickle.core.Visitor
import upickle.core.ArrVisitor
import upickle.core.ObjVisitor
import scalapb.descriptors as sd
import ujson.Transformer
import scalapb.GeneratedMessage

import com.google.protobuf.timestamp.Timestamp
import com.google.protobuf.duration.Duration
import com.google.protobuf.field_mask.FieldMask
import com.google.protobuf.timestamp.Timestamp
import com.google.protobuf.wrappers
import scalapb.descriptors as sd
import scalapb.FieldMaskUtil
import scalapb.GeneratedMessage
import ujson.Transformer
import upickle.core.ArrVisitor
import upickle.core.ObjVisitor
import upickle.core.Visitor

class JsonFormatException(msg: String, cause: Exception = null)
class JsonFormatException(msg: String, cause: Throwable = null)
extends Exception(msg, cause)

// something went wrong reading JSON as a protobuf message
class JsonReadException(
val message: String,
val position: Int,
cause: Exception = null
cause: Throwable = null
) extends JsonFormatException(s"$message (position: $position)", cause)

object JsonFormat:
Expand All @@ -29,6 +30,16 @@ object JsonFormat:
final val DurationDescriptor = Duration.scalaDescriptor
final val FieldMaskDescriptor = FieldMask.scalaDescriptor

final val DoubleValueDescriptor = wrappers.DoubleValue.scalaDescriptor
final val FloatValueDescriptor = wrappers.FloatValue.scalaDescriptor
final val Int32ValueDescriptor = wrappers.Int32Value.scalaDescriptor
final val Int64ValueDescriptor = wrappers.Int64Value.scalaDescriptor
final val UInt32ValueDescriptor = wrappers.UInt32Value.scalaDescriptor
final val UInt64ValueDescriptor = wrappers.UInt64Value.scalaDescriptor
final val BoolValueDescriptor = wrappers.BoolValue.scalaDescriptor
final val BytesValueDescriptor = wrappers.BytesValue.scalaDescriptor
final val StringValueDescriptor = wrappers.StringValue.scalaDescriptor

/** `this_is_snake_case => thisIsCamelCase` */
def camelify(snake: String): String =
val camel = new StringBuilder
Expand Down Expand Up @@ -149,29 +160,58 @@ class JsonFormat(

descriptor match
case JsonFormat.TimestampDescriptor =>
val seconds = fields(descriptor.findFieldByNumber(0).get).asInstanceOf[sd.PLong]
val nanos = fields(descriptor.findFieldByNumber(1).get).asInstanceOf[sd.PInt]
// PEmpty in a non-message field means that we're recursively completing with default values
(fields(descriptor.findFieldByNumber(1).get): @unchecked) match
case sd.PEmpty => out.visitString("1970-01-01T00:00:00Z", -1)
case seconds: sd.PLong =>
val nanos = fields(descriptor.findFieldByNumber(2).get).asInstanceOf[sd.PInt]

// TODO: not ideal that we need to rebuild a Scala class instance from a PValue
val str = Timestamps.writeTimestamp(Timestamp(seconds.value, nanos.value))
out.visitString(str, -1)
// TODO: not ideal that we need to rebuild a Scala class instance from a PValue
val str = Timestamps.writeTimestamp(Timestamp(seconds.value, nanos.value))
out.visitString(str, -1)

case JsonFormat.DurationDescriptor =>
val seconds = fields(descriptor.findFieldByNumber(0).get).asInstanceOf[sd.PLong]
val nanos = fields(descriptor.findFieldByNumber(1).get).asInstanceOf[sd.PInt]
(fields(descriptor.findFieldByNumber(1).get): @unchecked) match
case sd.PEmpty => out.visitString("0s", -1)
case seconds: sd.PLong =>
val nanos = fields(descriptor.findFieldByNumber(2).get).asInstanceOf[sd.PInt]

// TODO: not ideal that we need to rebuild a Scala class instance from a PValue
val str = Durations.writeDuration(Duration(seconds.value, nanos.value))
out.visitString(str, -1)
// TODO: not ideal that we need to rebuild a Scala class instance from a PValue
val str = Durations.writeDuration(Duration(seconds.value, nanos.value))
out.visitString(str, -1)

case JsonFormat.FieldMaskDescriptor =>
val paths = fields(descriptor.findFieldByNumber(0).get).asInstanceOf[sd.PRepeated]

// TODO: not ideal that we need to rebuild a Scala class instance from a PValue
val str = FieldMaskUtil.toJsonString(
FieldMask(paths.value.map(_.asInstanceOf[sd.PString].value))
)
out.visitString(str, -1)
(fields(descriptor.findFieldByNumber(1).get): @unchecked) match
case sd.PEmpty => out.visitString("", -1)
case paths: sd.PRepeated =>
// TODO: not ideal that we need to rebuild a Scala class instance from a PValue
val str = FieldMaskUtil.toJsonString(
FieldMask(paths.value.map(_.asInstanceOf[sd.PString].value))
)
out.visitString(str, -1)

case
JsonFormat.DoubleValueDescriptor |
JsonFormat.FloatValueDescriptor |
JsonFormat.Int32ValueDescriptor |
JsonFormat.Int64ValueDescriptor |
JsonFormat.UInt32ValueDescriptor |
JsonFormat.UInt64ValueDescriptor |
JsonFormat.BoolValueDescriptor |
JsonFormat.BytesValueDescriptor |
JsonFormat.StringValueDescriptor =>

val fd = descriptor.findFieldByNumber(1).get
val pv = fields(fd)

if pv != sd.PEmpty then
writePrimitive(
out,
fd,
fields(fd)
)
else
out.visitNull(-1)

case _ =>
val objVisitor = out.visitObject(
Expand All @@ -193,37 +233,43 @@ class JsonFormat(
value: sd.PValue
): Unit = value match
case sd.PEmpty =>
if includeDefaultValueFields && fd.containingOneof == None then
if includeDefaultValueFields then
out.visitKeyValue(out.visitKey(-1).visitString(jsonName(fd), -1))

// This is a bit of a trick: in ScalaPB, PEmpty is only used for message
// fields, so this check would be redundant. However, in order to avoid
// code duplication, we recursively call this function with PEmpty
// meaning the absence of a value, even for primitive types. This allows
// us to recursively construct nested default messages, without the need
// of duplicating logic in a separate function.
if fd.protoType.isTypeMessage then
val sd.ScalaType.Message(md) = (fd.scalaType: @unchecked)

if fd.containingOneof.isDefined then
out.narrow.visitValue(
writeMessage(
out.subVisitor,
md,
sd.PMessage(
md.fields.map(f => f -> sd.PEmpty).toMap
) // here PEmpty is not necessarily a missing *message* type
),
out.subVisitor.visitNull(-1),
-1
)
else
out.narrow.visitValue(
writePrimitive(
out.subVisitor,
fd,
JsonFormat.defaultPrimitiveValue(fd)
),
-1
)
// This is a bit of a trick: in ScalaPB, PEmpty is only used for message
// fields, so this check would be redundant. However, in order to avoid
// code duplication, we recursively call this function with PEmpty
// meaning the absence of a value, even for primitive types. This allows
// us to recursively construct nested default messages, without the need
// of duplicating logic in a separate function.
if fd.protoType.isTypeMessage then
val sd.ScalaType.Message(md) = (fd.scalaType: @unchecked)

out.narrow.visitValue(
writeMessage(
out.subVisitor,
md,
sd.PMessage(
md.fields.map(f => f -> sd.PEmpty).toMap
) // here PEmpty is not necessarily a missing *message* type
),
-1
)
else
out.narrow.visitValue(
writePrimitive(
out.subVisitor,
fd,
JsonFormat.defaultPrimitiveValue(fd)
),
-1
)
case sd.PRepeated(xs) =>
if xs.nonEmpty || includeDefaultValueFields then
out.visitKeyValue(out.visitKey(-1).visitString(jsonName(fd), -1))
Expand Down
28 changes: 27 additions & 1 deletion scalapb-ujson/src/scalapb/ujson/readers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import upickle.core.Visitor
import upickle.core.ArrVisitor
import upickle.core.ObjVisitor
import scalapb.descriptors as sd
import scalapb.FieldMaskUtil

class Reader(md: sd.Descriptor) extends SimpleVisitor[sd.PValue, sd.PMessage]:
override val expectedMsg: String = "expected JSON object"
Expand Down Expand Up @@ -113,13 +114,38 @@ private class FieldVisitor(var fd: sd.FieldDescriptor, inArray: Boolean = false)
ed.values.find(_.name == s.toString) match
case None => sd.PEmpty // ignore unknown value
case Some(ev) => sd.PEnum(ev)
else if fd.protoType.isTypeString then sd.PString(s.toString())
else if fd.protoType.isTypeString then
sd.PString(s.toString())
else if fd.protoType.isTypeBytes then
sd.PByteString(
com.google.protobuf.ByteString.copyFrom(
java.util.Base64.getDecoder().decode(s.toString)
)
)
else if fd.protoType.isTypeMessage then
val sd.ScalaType.Message(d) = (fd.scalaType: @unchecked)

def specialParse(tpe: String)(action: => sd.PValue) =
try
action
catch
case t: Throwable =>
throw JsonReadException(s"error for protobuf fiel '${fd.fullName}', parsing string as $tpe: ${t.getMessage}", index, t)

d match
case JsonFormat.TimestampDescriptor =>
specialParse("timestamp") {
Timestamps.parseTimestamp(s.toString).toPMessage
}
case JsonFormat.DurationDescriptor =>
specialParse("duration") {
Durations.parseDuration(s.toString).toPMessage
}
case JsonFormat.FieldMaskDescriptor =>
specialParse("fieldmask") {
FieldMaskUtil.fromJsonString(s.toString).toPMessage
}
case _ => unexpectedType("string", index)
else unexpectedType("string", index)

override def visitNull(index: Int) =
Expand Down
14 changes: 14 additions & 0 deletions scalapb-ujson/test/protobuf/protos.proto
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,18 @@ message Message {
}

bytes data = 100;

optional int32 optint = 16;
}

import "google/protobuf/timestamp.proto";
import "google/protobuf/duration.proto";
import "google/protobuf/wrappers.proto";
import "google/protobuf/field_mask.proto";

message SpecialFormats {
google.protobuf.Timestamp ts = 1;
google.protobuf.Duration duration = 2;
google.protobuf.Int32Value wrapper = 3;
google.protobuf.FieldMask fm = 4;
}
56 changes: 55 additions & 1 deletion scalapb-ujson/test/src/RwTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ object RwTests extends TestSuite:
| "repeated_nested": [],
| "messages": {},
| "nested_map": {},
| "data": ""
| "either_1": null,
| "either_2": null,
| "either_3": null,
| "either_4": null,
| "data": "",
| "optint": null
|}
|""".stripMargin
assertEqual(fmt, msg, expected, checkRead = false) // can't easily check reads with defaults included
Expand Down Expand Up @@ -256,4 +261,53 @@ object RwTests extends TestSuite:
)
assertEqual(fmt, msg, """{"data":"aGVsbG8sIHdvcmxk"}""")
}
test("optional") {
val fmt = JsonFormat(
includeDefaultValueFields = false
)
val msg = protos.Message(optint = Some(0))
assertEqual(fmt, msg, """{"optint": 0}""")
}
test("specials") {
test("no defaults") {
val fmt = JsonFormat(
includeDefaultValueFields = false
)

val msg = protos.SpecialFormats()
assertEqual(fmt, msg, """{}""")
}
test("defaults") {
val fmt = JsonFormat(
includeDefaultValueFields = true
)

val msg = protos.SpecialFormats()
assertEqual(fmt, msg, """{"ts":"1970-01-01T00:00:00Z","duration":"0s","wrapper": null, "fm":""}""", checkRead = false)
}
test("values") {
val fmt = JsonFormat(
includeDefaultValueFields = false
)

val msg = protos.SpecialFormats()
.withTs(
com.google.protobuf.timestamp.Timestamp(1678372591, 42)
)
.withDuration(
com.google.protobuf.duration.Duration(1000, 42)
)
.withFm(
com.google.protobuf.field_mask.FieldMask(Seq("a", "b.c"))
)
val expected =
"""|{
| "ts":"2023-03-09T14:36:31.000000042Z",
| "duration":"1000.000000042s",
| "fm":"a,b.c"
|}
|""".stripMargin
assertEqual(fmt, msg, expected)
}
}
}

0 comments on commit 49b2303

Please sign in to comment.