Skip to content

[JAVA-3814] Supports BsonIgnore annotation for scala driver #584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package org.mongodb.scala.bson.annotations

import scala.annotation.StaticAnnotation

/**
* Annotation to ignore a property.
*/
case class BsonIgnore() extends StaticAnnotation
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package org.mongodb.scala.bson.codecs.macrocodecs

import scala.reflect.macros.whitebox

import org.bson.codecs.Codec
import org.bson.codecs.configuration.CodecRegistry

import org.mongodb.scala.bson.annotations.BsonProperty
import org.mongodb.scala.bson.annotations.{ BsonIgnore, BsonProperty }

private[codecs] object CaseClassCodec {

Expand Down Expand Up @@ -157,6 +155,36 @@ private[codecs] object CaseClassCodec {
.toMap
}

val ignoredFields: Map[Type, Seq[(TermName, Tree)]] = {
knownTypes.map { tpe =>
if (!isCaseClass(tpe)) {
(tpe, Nil)
} else {
val constructor = tpe.decl(termNames.CONSTRUCTOR)
if (!constructor.isMethod) c.abort(c.enclosingPosition, "No constructor, unsupported class type")

val defaults = constructor.asMethod.paramLists.head
.map(_.asTerm)
.zipWithIndex
.filter(_._1.annotations.exists(_.tree.tpe == typeOf[BsonIgnore]))
.map {
case (p, i) =>
if (p.isParamWithDefault) {
val getterName = TermName("apply$default$" + (i + 1))
p.name -> q"${tpe.typeSymbol.companion}.$getterName"
} else {
c.abort(
c.enclosingPosition,
s"Field [${p.name}] with BsonIgnore annotation must have a default value"
)
}
}

tpe -> defaults
}
}.toMap
}

// Data converters
def keyName(t: Type): Literal = Literal(Constant(t.typeSymbol.name.decodedName.toString))
def keyNameTerm(t: TermName): Literal = Literal(classAnnotatedFieldsMap.getOrElse(t, Constant(t.toString)))
Expand Down Expand Up @@ -284,12 +312,14 @@ private[codecs] object CaseClassCodec {
* @param fields the list of fields
* @return the tree that writes the case class fields
*/
def writeClassValues(fields: List[(TermName, Type)]): List[Tree] = {
fields.map({
case (name, f) =>
val key = keyNameTerm(name)
f match {
case optional if isOption(optional) => q"""
def writeClassValues(fields: List[(TermName, Type)], ignoredFields: Seq[(TermName, Tree)]): List[Tree] = {
fields
.filterNot { case (name, _) => ignoredFields.exists { case (iname, _) => name == iname } }
.map({
case (name, f) =>
val key = keyNameTerm(name)
f match {
case optional if isOption(optional) => q"""
val localVal = instanceValue.$name
if (localVal.isDefined) {
writer.writeName($key)
Expand All @@ -298,13 +328,13 @@ private[codecs] object CaseClassCodec {
writer.writeName($key)
this.writeFieldValue($key, writer, this.bsonNull, encoderContext)
}"""
case _ => q"""
case _ => q"""
val localVal = instanceValue.$name
writer.writeName($key)
this.writeFieldValue($key, writer, localVal, encoderContext)
"""
}
})
}
})
}

/*
Expand All @@ -314,7 +344,7 @@ private[codecs] object CaseClassCodec {
val cases: Seq[Tree] = {
fields.map(field => cq""" ${keyName(field._1)} =>
val instanceValue = value.asInstanceOf[${field._1}]
..${writeClassValues(field._2)}""").toSeq
..${writeClassValues(field._2, ignoredFields(field._1))}""").toSeq
}

q"""
Expand All @@ -325,23 +355,29 @@ private[codecs] object CaseClassCodec {
"""
}

def fieldSetters(fields: List[(TermName, Type)]) = {
def fieldSetters(fields: List[(TermName, Type)], ignoredFields: Seq[(TermName, Tree)]) = {
fields.map({
case (name, f) =>
val key = keyNameTerm(name)
val missingField = Literal(Constant(s"Missing field: $key"))
f match {
case optional if isOption(optional) =>
q"$name = (if (fieldData.contains($key)) Option(fieldData($key)) else None).asInstanceOf[$f]"
case _ =>
q"""$name = fieldData.getOrElse($key, throw new BsonInvalidOperationException($missingField)).asInstanceOf[$f]"""

ignoredFields.find { case (iname, _) => name == iname }.map(_._2) match {
case Some(default) =>
q"$name = $default"
case None =>
f match {
case optional if isOption(optional) =>
q"$name = (if (fieldData.contains($key)) Option(fieldData($key)) else None).asInstanceOf[$f]"
case _ =>
q"""$name = fieldData.getOrElse($key, throw new BsonInvalidOperationException($missingField)).asInstanceOf[$f]"""
}
}
})
}

def getInstance = {
val cases = knownTypes.map { st =>
cq"${keyName(st)} => new $st(..${fieldSetters(fields(st))})"
cq"${keyName(st)} => new $st(..${fieldSetters(fields(st), ignoredFields(st))})"
} :+ cq"""_ => throw new BsonInvalidOperationException("Unexpected class type: " + className)"""
q"className match { case ..$cases }"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.bson.codecs.{ Codec, DecoderContext, EncoderContext }
import org.bson.io.{ BasicOutputBuffer, ByteBufferBsonInput, OutputBuffer }
import org.bson.types.ObjectId
import org.mongodb.scala.bson.BaseSpec
import org.mongodb.scala.bson.annotations.BsonProperty
import org.mongodb.scala.bson.annotations.{ BsonIgnore, BsonProperty }
import org.mongodb.scala.bson.codecs.Macros.{ createCodecProvider, createCodecProviderIgnoreNone }
import org.mongodb.scala.bson.codecs.Registry.DEFAULT_CODEC_REGISTRY
import org.mongodb.scala.bson.collection.immutable.Document
Expand All @@ -43,6 +43,7 @@ class MacrosSpec extends BaseSpec {
case class SeqOfStrings(name: String, value: Seq[String])
case class RecursiveSeq(name: String, value: Seq[RecursiveSeq])
case class AnnotatedClass(@BsonProperty("annotated_name") name: String)
case class IgnoredFieldClass(name: String, @BsonIgnore meta: String = "ignored_default")

case class Binary(binary: Array[Byte]) {

Expand Down Expand Up @@ -103,6 +104,11 @@ class MacrosSpec extends BaseSpec {
case class Branch(@BsonProperty("l1") b1: Tree, @BsonProperty("r1") b2: Tree, value: Int) extends Tree
case class Leaf(value: Int) extends Tree

sealed trait WithIgnored
case class MetaIgnoredField(data: String, @BsonIgnore meta: Seq[String] = Vector("ignore_me")) extends WithIgnored
case class LeafCountIgnoredField(branchCount: Int, @BsonIgnore leafCount: Int = 100) extends WithIgnored
case class ContainsIgnoredField(list: Seq[WithIgnored])

case class ContainsADT(name: String, tree: Tree)
case class ContainsSeqADT(name: String, trees: Seq[Tree])
case class ContainsNestedSeqADT(name: String, trees: Seq[Seq[Tree]])
Expand Down Expand Up @@ -270,6 +276,23 @@ class MacrosSpec extends BaseSpec {
)
}

it should "be able to ignore fields" in {
roundTrip(
IgnoredFieldClass("Bob", "singer"),
IgnoredFieldClass("Bob"),
"""{name: "Bob"}""",
classOf[IgnoredFieldClass]
)

roundTrip(
ContainsIgnoredField(Vector(MetaIgnoredField("Bob", List("singer")), LeafCountIgnoredField(1, 10))),
ContainsIgnoredField(Vector(MetaIgnoredField("Bob"), LeafCountIgnoredField(1))),
"""{"list" : [{"_t" : "MetaIgnoredField", "data" : "Bob" }, {"_t" : "LeafCountIgnoredField", "branchCount": 1}]}""",
classOf[ContainsIgnoredField],
classOf[WithIgnored]
)
}

it should "be able to round trip polymorphic nested case classes in a sealed class" in {
roundTrip(
ContainsSealedClass(List(SealedClassA("test"), SealedClassB(12))),
Expand Down Expand Up @@ -657,6 +680,15 @@ class MacrosSpec extends BaseSpec {
roundTripCodec(value, Document(expected), codec)
}

def roundTrip[T](value: T, decodedValue: T, expected: String, provider: CodecProvider, providers: CodecProvider*)(
implicit ct: ClassTag[T]
): Unit = {
val codecProviders: util.List[CodecProvider] = (provider +: providers).asJava
val registry = CodecRegistries.fromRegistries(CodecRegistries.fromProviders(codecProviders), DEFAULT_CODEC_REGISTRY)
val codec = registry.get(ct.runtimeClass).asInstanceOf[Codec[T]]
roundTripCodec(value, decodedValue, Document(expected), codec)
}

def roundTripCodec[T](value: T, expected: Document, codec: Codec[T]): Unit = {
val encoded = encode(codec, value)
val actual = decode(documentCodec, encoded)
Expand All @@ -666,6 +698,18 @@ class MacrosSpec extends BaseSpec {
assert(roundTripped == value, s"Round Tripped case class: ($roundTripped) did not equal the original: ($value)")
}

def roundTripCodec[T](value: T, decodedValue: T, expected: Document, codec: Codec[T]): Unit = {
val encoded = encode(codec, value)
val actual = decode(documentCodec, encoded)
assert(expected == actual, s"Encoded document: (${actual.toJson()}) did not equal: (${expected.toJson()})")

val roundTripped = decode(codec, encode(codec, value))
assert(
roundTripped == decodedValue,
s"Round Tripped case class: ($roundTripped) did not equal the expected: ($decodedValue)"
)
}

def encode[T](codec: Codec[T], value: T): OutputBuffer = {
val buffer = new BasicOutputBuffer()
val writer = new BsonBinaryWriter(buffer)
Expand Down