Skip to content

Commit

Permalink
fixed 2096 (#2098)
Browse files Browse the repository at this point in the history
Co-authored-by: Darren Gibson <zarthross@users.noreply.github.com>
  • Loading branch information
kailuowang and zarthross committed Mar 2, 2023
1 parent 980d1f7 commit 6bbb7e7
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 11 deletions.
Expand Up @@ -6,13 +6,20 @@ import io.circe.{ Codec, Decoder, Encoder, HCursor, JsonObject }

trait ConfiguredCodec[A] extends Codec.AsObject[A], ConfiguredDecoder[A], ConfiguredEncoder[A]
object ConfiguredCodec:
inline final def derived[A](using conf: Configuration)(using inline mirror: Mirror.Of[A]): ConfiguredCodec[A] =
new ConfiguredCodec[A]:

inline final def derived[A](using conf: Configuration)(using
inline mirror: Mirror.Of[A]
): ConfiguredCodec[A] =
new ConfiguredCodec[A] with SumOrProduct:
val name = constValue[mirror.MirroredLabel]
lazy val elemLabels: List[String] = summonLabels[mirror.MirroredElemLabels]
lazy val elemEncoders: List[Encoder[?]] = summonEncoders[mirror.MirroredElemTypes]
lazy val elemDecoders: List[Decoder[?]] = summonDecoders[mirror.MirroredElemTypes]
lazy val elemDefaults: Default[A] = Predef.summon[Default[A]]
lazy val isSum: Boolean =
inline mirror match
case _: Mirror.ProductOf[A] => false
case _: Mirror.SumOf[A] => true

final def encodeObject(a: A): JsonObject =
inline mirror match
Expand Down
Expand Up @@ -6,6 +6,8 @@ import Predef.genericArrayOps
import cats.data.{ NonEmptyList, Validated }
import io.circe.{ ACursor, Decoder, DecodingFailure, HCursor }
import io.circe.DecodingFailure.Reason.WrongTypeExpectation
import cats.implicits.*
import scala.collection.immutable.Map

trait ConfiguredDecoder[A](using conf: Configuration) extends Decoder[A]:
val name: String
Expand All @@ -14,15 +16,30 @@ trait ConfiguredDecoder[A](using conf: Configuration) extends Decoder[A]:
lazy val elemDefaults: Default[A]
lazy val constructorNames: List[String] = elemLabels.map(conf.transformConstructorNames)

private lazy val decodersDict: Map[String, Decoder[?]] = {
def findDecoderDict(p: (String, Decoder[?])): List[(String, Decoder[?])] =
p._2 match {
case cd: ConfiguredDecoder[?] with SumOrProduct if cd.isSum =>
cd.constructorNames.zip(cd.elemDecoders).flatMap(findDecoderDict)
case _ => List(p)
}
constructorNames.zip(elemDecoders).flatMap(findDecoderDict).toMap
}

private def strictDecodingFailure(c: HCursor, message: String): DecodingFailure =
DecodingFailure(s"Strict decoding $name - $message", c.history)

/** Decodes a class/object/case of a Sum type handling discriminator and strict decoding. */
private def decodeSumElement[R](c: HCursor)(fail: DecodingFailure => R, decode: Decoder[A] => ACursor => R): R =

def fromName(sumTypeName: String, cursor: ACursor): R =
constructorNames.indexOf(sumTypeName) match
case -1 => fail(DecodingFailure(s"type $name has no class/object/case named '$sumTypeName'.", cursor.history))
case index => decode(elemDecoders(index).asInstanceOf[Decoder[A]])(cursor)
decodersDict
.get(sumTypeName)
.fold(
fail(DecodingFailure(s"type $name has no class/object/case named '$sumTypeName'.", cursor.history))
) { decoder =>
decode(decoder.asInstanceOf[Decoder[A]])(cursor)
}

conf.discriminator match
case Some(discriminator) =>
Expand Down Expand Up @@ -156,13 +173,20 @@ trait ConfiguredDecoder[A](using conf: Configuration) extends Decoder[A]:
}

object ConfiguredDecoder:
inline final def derived[A](using conf: Configuration)(using inline mirror: Mirror.Of[A]): ConfiguredDecoder[A] =
new ConfiguredDecoder[A]:
inline final def derived[A](using conf: Configuration)(using
inline mirror: Mirror.Of[A]
): ConfiguredDecoder[A] =
new ConfiguredDecoder[A] with SumOrProduct:
val name = constValue[mirror.MirroredLabel]
lazy val elemLabels: List[String] = summonLabels[mirror.MirroredElemLabels]
lazy val elemDecoders: List[Decoder[?]] = summonDecoders[mirror.MirroredElemTypes]
lazy val elemDefaults: Default[A] = Predef.summon[Default[A]]

lazy val isSum: Boolean =
inline mirror match
case _: Mirror.ProductOf[A] => false
case _: Mirror.SumOf[A] => true

final def apply(c: HCursor): Decoder.Result[A] =
inline mirror match
case product: Mirror.ProductOf[A] => decodeProduct(c, product.fromProduct)
Expand Down
Expand Up @@ -8,8 +8,9 @@ trait ConfiguredEncoder[A](using conf: Configuration) extends Encoder.AsObject[A
lazy val elemLabels: List[String]
lazy val elemEncoders: List[Encoder[?]]

final def encodeElemAt(index: Int, elem: Any, transformName: String => String): (String, Json) =
final def encodeElemAt(index: Int, elem: Any, transformName: String => String): (String, Json) = {
(transformName(elemLabels(index)), elemEncoders(index).asInstanceOf[Encoder[Any]].apply(elem))
}

final def encodeProduct(a: A): JsonObject =
val product = a.asInstanceOf[Product]
Expand All @@ -20,17 +21,31 @@ trait ConfiguredEncoder[A](using conf: Configuration) extends Encoder.AsObject[A

final def encodeSum(index: Int, a: A): JsonObject =
val (constructorName, json) = encodeElemAt(index, a, conf.transformConstructorNames)

conf.discriminator match
case None => JsonObject.singleton(constructorName, json)
case Some(discriminator) =>
json.asObject.getOrElse(JsonObject.empty).add(discriminator, Json.fromString(constructorName))
val jo = json.asObject.getOrElse(JsonObject.empty)
val elemIsSum = elemEncoders(index) match {
case ce: ConfiguredEncoder[?] with SumOrProduct => ce.isSum
case _ => jo.contains(discriminator)
}
if (elemIsSum)
jo
else jo.add(discriminator, Json.fromString(constructorName)) // only add discriminator if elem is a Product

case None => JsonObject.singleton(constructorName, json)

object ConfiguredEncoder:
inline final def derived[A](using conf: Configuration)(using inline mirror: Mirror.Of[A]): ConfiguredEncoder[A] =
new ConfiguredEncoder[A]:
new ConfiguredEncoder[A] with SumOrProduct:
lazy val elemLabels: List[String] = summonLabels[mirror.MirroredElemLabels]
lazy val elemEncoders: List[Encoder[?]] = summonEncoders[mirror.MirroredElemTypes]

lazy val isSum: Boolean =
inline mirror match
case _: Mirror.ProductOf[A] => false
case _: Mirror.SumOf[A] => true

final def encodeObject(a: A): JsonObject =
inline mirror match
case _: Mirror.ProductOf[A] => encodeProduct(a)
Expand Down
@@ -0,0 +1,5 @@
package io.circe.derivation

private[derivation] trait SumOrProduct {
def isSum: Boolean
}
Expand Up @@ -426,3 +426,65 @@ class ConfiguredDerivesSuite extends CirceMunitSuite:
)
)
}

{
given Configuration = Configuration.default.withDiscriminator("type")

sealed trait GrandParent derives ConfiguredCodec
object GrandParent:
given Eq[GrandParent] = Eq.fromUniversalEquals

sealed trait Parent extends GrandParent

case class Child(a: Int, b: String) extends Parent

test("Codec for hierarchy of more than 1 level with discriminator should encode and decode correctly") {
val child: GrandParent = Child(1, "a")
val json = Encoder.AsObject[GrandParent].apply(child)
val result = Decoder[GrandParent].decodeJson(json)
assert(result === Right(child), result)
}
}

{
sealed trait GreatGrandParent
object GreatGrandParent:
given Eq[GreatGrandParent] = Eq.fromUniversalEquals[GreatGrandParent]

sealed trait GrandParent extends GreatGrandParent
case class Uncle(Child: Int)
extends GrandParent // The field name, `Child` matches a existing case class in the hierarchy and is important for the tests.
sealed trait Parent extends GrandParent

case class Child(a: Int, b: String) extends Parent

test(
"Codec for hierarchy of more than 2 level with discriminator should encode and decode correctly, even if a parent's sibling has a field with the same name as a Child type"
) {
given Configuration = Configuration.default.withDiscriminator("type")
given Codec.AsObject[GreatGrandParent] = ConfiguredCodec.derived[GreatGrandParent]

val child: GrandParent = Child(1, "a")
val json = Encoder.AsObject[GreatGrandParent].apply(child)
val result = Decoder[GreatGrandParent].decodeJson(json)
assert(result === Right(child), result)
}

}

{
given Configuration = Configuration.default.withDiscriminator("type")

sealed trait Tree derives ConfiguredCodec;
case class Branch(l: Tree, r: Tree) extends Tree;
case object Leaf extends Tree;
object Tree:
given Eq[Tree] = Eq.fromUniversalEquals[Tree]

test("Codec for recursive type should encode and decode correctly") {
val tree: Tree = Branch(Branch(Leaf, Leaf), Leaf)
val json = Encoder.AsObject[Tree].apply(tree)
val result = Decoder[Tree].decodeJson(json)
assert(result === Right(tree), result)
}
}

0 comments on commit 6bbb7e7

Please sign in to comment.