Skip to content

Commit

Permalink
Support discriminators not being the first field when decoding (#1324)
Browse files Browse the repository at this point in the history
  • Loading branch information
rozza committed Mar 15, 2024
1 parent 375bbd6 commit d9bedd4
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 34 deletions.
100 changes: 77 additions & 23 deletions bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/BsonDecoder.kt
Expand Up @@ -31,6 +31,7 @@ import kotlinx.serialization.modules.SerializersModule
import org.bson.AbstractBsonReader
import org.bson.BsonInvalidOperationException
import org.bson.BsonReader
import org.bson.BsonReaderMark
import org.bson.BsonType
import org.bson.BsonValue
import org.bson.codecs.BsonValueCodec
Expand Down Expand Up @@ -68,6 +69,20 @@ internal open class DefaultBsonDecoder(
val validKeyKinds = setOf(PrimitiveKind.STRING, PrimitiveKind.CHAR, SerialKind.ENUM)
val bsonValueCodec = BsonValueCodec()
const val UNKNOWN_INDEX = -10
fun validateCurrentBsonType(
reader: AbstractBsonReader,
expectedType: BsonType,
descriptor: SerialDescriptor,
actualType: (descriptor: SerialDescriptor) -> String = { it.kind.toString() }
) {
reader.currentBsonType?.let {
if (it != expectedType) {
throw SerializationException(
"Invalid data for `${actualType(descriptor)}` expected a bson " +
"${expectedType.name.lowercase()} found: ${reader.currentBsonType}")
}
}
}
}

private fun initElementMetadata(descriptor: SerialDescriptor) {
Expand Down Expand Up @@ -119,29 +134,14 @@ internal open class DefaultBsonDecoder(

@Suppress("ReturnCount")
override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
when (descriptor.kind) {
is StructureKind.LIST -> {
reader.readStartArray()
return BsonArrayDecoder(reader, serializersModule, configuration)
}
is PolymorphicKind -> {
reader.readStartDocument()
return PolymorphicDecoder(reader, serializersModule, configuration)
}
return when (descriptor.kind) {
is StructureKind.LIST -> BsonArrayDecoder(descriptor, reader, serializersModule, configuration)
is PolymorphicKind -> PolymorphicDecoder(descriptor, reader, serializersModule, configuration)
is StructureKind.CLASS,
StructureKind.OBJECT -> {
val current = reader.currentBsonType
if (current == null || current == BsonType.DOCUMENT) {
reader.readStartDocument()
}
}
is StructureKind.MAP -> {
reader.readStartDocument()
return BsonDocumentDecoder(reader, serializersModule, configuration)
}
StructureKind.OBJECT -> BsonDocumentDecoder(descriptor, reader, serializersModule, configuration)
is StructureKind.MAP -> MapDecoder(descriptor, reader, serializersModule, configuration)
else -> throw SerializationException("Primitives are not supported at top-level")
}
return DefaultBsonDecoder(reader, serializersModule, configuration)
}

override fun endStructure(descriptor: SerialDescriptor) {
Expand Down Expand Up @@ -194,10 +194,17 @@ internal open class DefaultBsonDecoder(

@OptIn(ExperimentalSerializationApi::class)
private class BsonArrayDecoder(
descriptor: SerialDescriptor,
reader: AbstractBsonReader,
serializersModule: SerializersModule,
configuration: BsonConfiguration
) : DefaultBsonDecoder(reader, serializersModule, configuration) {

init {
validateCurrentBsonType(reader, BsonType.ARRAY, descriptor)
reader.readStartArray()
}

private var index = 0
override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
val nextType = reader.readBsonType()
Expand All @@ -208,18 +215,46 @@ private class BsonArrayDecoder(

@OptIn(ExperimentalSerializationApi::class)
private class PolymorphicDecoder(
descriptor: SerialDescriptor,
reader: AbstractBsonReader,
serializersModule: SerializersModule,
configuration: BsonConfiguration
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
private var index = 0
private var mark: BsonReaderMark?

override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T =
deserializer.deserialize(DefaultBsonDecoder(reader, serializersModule, configuration))
init {
mark = reader.mark
validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) { it.serialName }
reader.readStartDocument()
}

override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
mark?.let {
it.reset()
mark = null
}
return deserializer.deserialize(DefaultBsonDecoder(reader, serializersModule, configuration))
}

override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
var found = false
return when (index) {
0 -> index++
0 -> {
while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
if (reader.readName() == configuration.classDiscriminator) {
found = true
break
}
reader.skipValue()
}
if (!found) {
throw SerializationException(
"Missing required discriminator field `${configuration.classDiscriminator}` " +
"for polymorphic class: `${descriptor.serialName}`.")
}
index++
}
1 -> index++
else -> DECODE_DONE
}
Expand All @@ -228,6 +263,20 @@ private class PolymorphicDecoder(

@OptIn(ExperimentalSerializationApi::class)
private class BsonDocumentDecoder(
descriptor: SerialDescriptor,
reader: AbstractBsonReader,
serializersModule: SerializersModule,
configuration: BsonConfiguration
) : DefaultBsonDecoder(reader, serializersModule, configuration) {
init {
validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) { it.serialName }
reader.readStartDocument()
}
}

@OptIn(ExperimentalSerializationApi::class)
private class MapDecoder(
descriptor: SerialDescriptor,
reader: AbstractBsonReader,
serializersModule: SerializersModule,
configuration: BsonConfiguration
Expand All @@ -236,6 +285,11 @@ private class BsonDocumentDecoder(
private var index = 0
private var isKey = false

init {
validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor)
reader.readStartDocument()
}

override fun decodeString(): String {
return if (isKey) {
reader.readName()
Expand Down
Expand Up @@ -16,14 +16,19 @@
package org.bson.codecs.kotlinx

import com.mongodb.MongoClientSettings
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import org.bson.codecs.DecoderContext
import org.bson.codecs.kotlinx.samples.DataClassParameterized
import org.bson.codecs.kotlinx.samples.DataClassSealedInterface
import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues
import org.bson.codecs.kotlinx.samples.SealedInterface
import org.bson.conversions.Bson
import org.bson.json.JsonReader
import org.bson.types.ObjectId
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue

class KotlinSerializerCodecProviderTest {

Expand Down Expand Up @@ -60,4 +65,39 @@ class KotlinSerializerCodecProviderTest {
assertTrue { codec is KotlinSerializerCodec }
assertEquals(DataClassWithSimpleValues::class.java, codec.encoderClass)
}

@Test
fun testDataClassWithSimpleValuesFieldOrdering() {
val codec = MongoClientSettings.getDefaultCodecRegistry().get(DataClassWithSimpleValues::class.java)
val expected = DataClassWithSimpleValues('c', 0, 1, 22, 42L, 4.0f, 4.2, true, "String")

val numberLong = "\$numberLong"
val actual =
codec.decode(
JsonReader(
"""{"boolean": true, "byte": 0, "char": "c", "double": 4.2, "float": 4.0, "int": 22,
|"long": {"$numberLong": "42"}, "short": 1, "string": "String"}"""
.trimMargin()),
DecoderContext.builder().build())

assertEquals(expected, actual)
}

@Test
fun testDataClassSealedFieldOrdering() {
val codec = MongoClientSettings.getDefaultCodecRegistry().get(SealedInterface::class.java)

val objectId = ObjectId("111111111111111111111111")
val oid = "\$oid"
val expected = DataClassSealedInterface(objectId, "string")
val actual =
codec.decode(
JsonReader(
"""{"name": "string", "_id": {$oid: "${objectId.toHexString()}"},
|"_t": "org.bson.codecs.kotlinx.samples.DataClassSealedInterface"}"""
.trimMargin()),
DecoderContext.builder().build())

assertEquals(expected, actual)
}
}
Expand Up @@ -84,20 +84,23 @@ import org.bson.codecs.kotlinx.samples.DataClassWithSequence
import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues
import org.bson.codecs.kotlinx.samples.DataClassWithTriple
import org.bson.codecs.kotlinx.samples.Key
import org.bson.codecs.kotlinx.samples.SealedInterface
import org.bson.codecs.kotlinx.samples.ValueClass
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

@OptIn(ExperimentalSerializationApi::class)
@Suppress("LargeClass")
class KotlinSerializerCodecTest {
private val numberLong = "\$numberLong"
private val oid = "\$oid"
private val emptyDocument = "{}"
private val altConfiguration =
BsonConfiguration(encodeDefaults = false, classDiscriminator = "_t", explicitNulls = true)

private val allBsonTypesJson =
"""{
| "id": {"${'$'}oid": "111111111111111111111111"},
| "id": {"$oid": "111111111111111111111111"},
| "arrayEmpty": [],
| "arraySimple": [{"${'$'}numberInt": "1"}, {"${'$'}numberInt": "2"}, {"${'$'}numberInt": "3"}],
| "arrayComplex": [{"a": {"${'$'}numberInt": "1"}}, {"a": {"${'$'}numberInt": "2"}}],
Expand Down Expand Up @@ -668,17 +671,49 @@ class KotlinSerializerCodecTest {
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}

assertThrows<MissingFieldException>("Invalid complex types") {
val data = BsonDocument.parse("""{"_id": "myId", "embedded": 123}""")
val codec = KotlinSerializerCodec.create<DataClassWithEmbedded>()
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}

assertThrows<IllegalArgumentException>("Failing init") {
val data = BsonDocument.parse("""{"id": "myId"}""")
val codec = KotlinSerializerCodec.create<DataClassWithFailingInit>()
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}

var exception =
assertThrows<SerializationException>("Invalid complex types - document") {
val data = BsonDocument.parse("""{"_id": "myId", "embedded": 123}""")
val codec = KotlinSerializerCodec.create<DataClassWithEmbedded>()
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}
assertEquals(
"Invalid data for `org.bson.codecs.kotlinx.samples.DataClassEmbedded` " +
"expected a bson document found: INT32",
exception.message)

exception =
assertThrows<SerializationException>("Invalid complex types - list") {
val data = BsonDocument.parse("""{"_id": "myId", "nested": 123}""")
val codec = KotlinSerializerCodec.create<DataClassListOfDataClasses>()
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}
assertEquals("Invalid data for `LIST` expected a bson array found: INT32", exception.message)

exception =
assertThrows<SerializationException>("Invalid complex types - map") {
val data = BsonDocument.parse("""{"_id": "myId", "nested": 123}""")
val codec = KotlinSerializerCodec.create<DataClassMapOfDataClasses>()
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}
assertEquals("Invalid data for `MAP` expected a bson document found: INT32", exception.message)

exception =
assertThrows<SerializationException>("Missing discriminator") {
val data = BsonDocument.parse("""{"_id": {"$oid": "111111111111111111111111"}, "name": "string"}""")
val codec = KotlinSerializerCodec.create<SealedInterface>()
codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build())
}
assertEquals(
"Missing required discriminator field `_t` for polymorphic class: " +
"`org.bson.codecs.kotlinx.samples.SealedInterface`.",
exception.message)
}

@Test
Expand Down
Expand Up @@ -245,6 +245,15 @@ data class DataClassOptionalBsonValues(

@Serializable @SerialName("C") data class DataClassSealedC(val c: String) : DataClassSealed()

@Serializable
sealed interface SealedInterface {
val name: String
}

@Serializable
data class DataClassSealedInterface(@Contextual @SerialName("_id") val id: ObjectId, override val name: String) :
SealedInterface

@Serializable data class DataClassListOfSealed(val items: List<DataClassSealed>)

interface DataClassOpen
Expand Down

0 comments on commit d9bedd4

Please sign in to comment.