From d9bedd49ab2ece07d08799ea2e69b9b8c6225bb9 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 15 Mar 2024 14:36:03 +0000 Subject: [PATCH] Support discriminators not being the first field when decoding (#1324) JAVA-5304 --- .../org/bson/codecs/kotlinx/BsonDecoder.kt | 100 ++++++++++++++---- .../KotlinSerializerCodecProviderTest.kt | 48 ++++++++- .../kotlinx/KotlinSerializerCodecTest.kt | 49 +++++++-- .../codecs/kotlinx/samples/DataClasses.kt | 9 ++ 4 files changed, 172 insertions(+), 34 deletions(-) diff --git a/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/BsonDecoder.kt b/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/BsonDecoder.kt index b4cbad3b9dd..435964d4ac0 100644 --- a/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/BsonDecoder.kt +++ b/bson-kotlinx/src/main/kotlin/org/bson/codecs/kotlinx/BsonDecoder.kt @@ -31,6 +31,7 @@ import kotlinx.serialization.modules.SerializersModule import org.bson.AbstractBsonReader import org.bson.BsonInvalidOperationException import org.bson.BsonReader +import org.bson.BsonReaderMark import org.bson.BsonType import org.bson.BsonValue import org.bson.codecs.BsonValueCodec @@ -68,6 +69,20 @@ internal open class DefaultBsonDecoder( val validKeyKinds = setOf(PrimitiveKind.STRING, PrimitiveKind.CHAR, SerialKind.ENUM) val bsonValueCodec = BsonValueCodec() const val UNKNOWN_INDEX = -10 + fun validateCurrentBsonType( + reader: AbstractBsonReader, + expectedType: BsonType, + descriptor: SerialDescriptor, + actualType: (descriptor: SerialDescriptor) -> String = { it.kind.toString() } + ) { + reader.currentBsonType?.let { + if (it != expectedType) { + throw SerializationException( + "Invalid data for `${actualType(descriptor)}` expected a bson " + + "${expectedType.name.lowercase()} found: ${reader.currentBsonType}") + } + } + } } private fun initElementMetadata(descriptor: SerialDescriptor) { @@ -119,29 +134,14 @@ internal open class DefaultBsonDecoder( @Suppress("ReturnCount") override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { - when (descriptor.kind) { - is StructureKind.LIST -> { - reader.readStartArray() - return BsonArrayDecoder(reader, serializersModule, configuration) - } - is PolymorphicKind -> { - reader.readStartDocument() - return PolymorphicDecoder(reader, serializersModule, configuration) - } + return when (descriptor.kind) { + is StructureKind.LIST -> BsonArrayDecoder(descriptor, reader, serializersModule, configuration) + is PolymorphicKind -> PolymorphicDecoder(descriptor, reader, serializersModule, configuration) is StructureKind.CLASS, - StructureKind.OBJECT -> { - val current = reader.currentBsonType - if (current == null || current == BsonType.DOCUMENT) { - reader.readStartDocument() - } - } - is StructureKind.MAP -> { - reader.readStartDocument() - return BsonDocumentDecoder(reader, serializersModule, configuration) - } + StructureKind.OBJECT -> BsonDocumentDecoder(descriptor, reader, serializersModule, configuration) + is StructureKind.MAP -> MapDecoder(descriptor, reader, serializersModule, configuration) else -> throw SerializationException("Primitives are not supported at top-level") } - return DefaultBsonDecoder(reader, serializersModule, configuration) } override fun endStructure(descriptor: SerialDescriptor) { @@ -194,10 +194,17 @@ internal open class DefaultBsonDecoder( @OptIn(ExperimentalSerializationApi::class) private class BsonArrayDecoder( + descriptor: SerialDescriptor, reader: AbstractBsonReader, serializersModule: SerializersModule, configuration: BsonConfiguration ) : DefaultBsonDecoder(reader, serializersModule, configuration) { + + init { + validateCurrentBsonType(reader, BsonType.ARRAY, descriptor) + reader.readStartArray() + } + private var index = 0 override fun decodeElementIndex(descriptor: SerialDescriptor): Int { val nextType = reader.readBsonType() @@ -208,18 +215,46 @@ private class BsonArrayDecoder( @OptIn(ExperimentalSerializationApi::class) private class PolymorphicDecoder( + descriptor: SerialDescriptor, reader: AbstractBsonReader, serializersModule: SerializersModule, configuration: BsonConfiguration ) : DefaultBsonDecoder(reader, serializersModule, configuration) { private var index = 0 + private var mark: BsonReaderMark? - override fun decodeSerializableValue(deserializer: DeserializationStrategy): T = - deserializer.deserialize(DefaultBsonDecoder(reader, serializersModule, configuration)) + init { + mark = reader.mark + validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) { it.serialName } + reader.readStartDocument() + } + + override fun decodeSerializableValue(deserializer: DeserializationStrategy): T { + mark?.let { + it.reset() + mark = null + } + return deserializer.deserialize(DefaultBsonDecoder(reader, serializersModule, configuration)) + } override fun decodeElementIndex(descriptor: SerialDescriptor): Int { + var found = false return when (index) { - 0 -> index++ + 0 -> { + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + if (reader.readName() == configuration.classDiscriminator) { + found = true + break + } + reader.skipValue() + } + if (!found) { + throw SerializationException( + "Missing required discriminator field `${configuration.classDiscriminator}` " + + "for polymorphic class: `${descriptor.serialName}`.") + } + index++ + } 1 -> index++ else -> DECODE_DONE } @@ -228,6 +263,20 @@ private class PolymorphicDecoder( @OptIn(ExperimentalSerializationApi::class) private class BsonDocumentDecoder( + descriptor: SerialDescriptor, + reader: AbstractBsonReader, + serializersModule: SerializersModule, + configuration: BsonConfiguration +) : DefaultBsonDecoder(reader, serializersModule, configuration) { + init { + validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) { it.serialName } + reader.readStartDocument() + } +} + +@OptIn(ExperimentalSerializationApi::class) +private class MapDecoder( + descriptor: SerialDescriptor, reader: AbstractBsonReader, serializersModule: SerializersModule, configuration: BsonConfiguration @@ -236,6 +285,11 @@ private class BsonDocumentDecoder( private var index = 0 private var isKey = false + init { + validateCurrentBsonType(reader, BsonType.DOCUMENT, descriptor) + reader.readStartDocument() + } + override fun decodeString(): String { return if (isKey) { reader.readName() diff --git a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt index 0870e2033e9..e05fc8f34f1 100644 --- a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt +++ b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecProviderTest.kt @@ -16,14 +16,19 @@ package org.bson.codecs.kotlinx import com.mongodb.MongoClientSettings -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertNull -import kotlin.test.assertTrue +import org.bson.codecs.DecoderContext import org.bson.codecs.kotlinx.samples.DataClassParameterized +import org.bson.codecs.kotlinx.samples.DataClassSealedInterface import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues +import org.bson.codecs.kotlinx.samples.SealedInterface import org.bson.conversions.Bson +import org.bson.json.JsonReader +import org.bson.types.ObjectId import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue class KotlinSerializerCodecProviderTest { @@ -60,4 +65,39 @@ class KotlinSerializerCodecProviderTest { assertTrue { codec is KotlinSerializerCodec } assertEquals(DataClassWithSimpleValues::class.java, codec.encoderClass) } + + @Test + fun testDataClassWithSimpleValuesFieldOrdering() { + val codec = MongoClientSettings.getDefaultCodecRegistry().get(DataClassWithSimpleValues::class.java) + val expected = DataClassWithSimpleValues('c', 0, 1, 22, 42L, 4.0f, 4.2, true, "String") + + val numberLong = "\$numberLong" + val actual = + codec.decode( + JsonReader( + """{"boolean": true, "byte": 0, "char": "c", "double": 4.2, "float": 4.0, "int": 22, + |"long": {"$numberLong": "42"}, "short": 1, "string": "String"}""" + .trimMargin()), + DecoderContext.builder().build()) + + assertEquals(expected, actual) + } + + @Test + fun testDataClassSealedFieldOrdering() { + val codec = MongoClientSettings.getDefaultCodecRegistry().get(SealedInterface::class.java) + + val objectId = ObjectId("111111111111111111111111") + val oid = "\$oid" + val expected = DataClassSealedInterface(objectId, "string") + val actual = + codec.decode( + JsonReader( + """{"name": "string", "_id": {$oid: "${objectId.toHexString()}"}, + |"_t": "org.bson.codecs.kotlinx.samples.DataClassSealedInterface"}""" + .trimMargin()), + DecoderContext.builder().build()) + + assertEquals(expected, actual) + } } diff --git a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecTest.kt b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecTest.kt index 146e897c59b..14fcfa8a01c 100644 --- a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecTest.kt +++ b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/KotlinSerializerCodecTest.kt @@ -84,20 +84,23 @@ import org.bson.codecs.kotlinx.samples.DataClassWithSequence import org.bson.codecs.kotlinx.samples.DataClassWithSimpleValues import org.bson.codecs.kotlinx.samples.DataClassWithTriple import org.bson.codecs.kotlinx.samples.Key +import org.bson.codecs.kotlinx.samples.SealedInterface import org.bson.codecs.kotlinx.samples.ValueClass import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows @OptIn(ExperimentalSerializationApi::class) +@Suppress("LargeClass") class KotlinSerializerCodecTest { private val numberLong = "\$numberLong" + private val oid = "\$oid" private val emptyDocument = "{}" private val altConfiguration = BsonConfiguration(encodeDefaults = false, classDiscriminator = "_t", explicitNulls = true) private val allBsonTypesJson = """{ - | "id": {"${'$'}oid": "111111111111111111111111"}, + | "id": {"$oid": "111111111111111111111111"}, | "arrayEmpty": [], | "arraySimple": [{"${'$'}numberInt": "1"}, {"${'$'}numberInt": "2"}, {"${'$'}numberInt": "3"}], | "arrayComplex": [{"a": {"${'$'}numberInt": "1"}}, {"a": {"${'$'}numberInt": "2"}}], @@ -668,17 +671,49 @@ class KotlinSerializerCodecTest { codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) } - assertThrows("Invalid complex types") { - val data = BsonDocument.parse("""{"_id": "myId", "embedded": 123}""") - val codec = KotlinSerializerCodec.create() - codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) - } - assertThrows("Failing init") { val data = BsonDocument.parse("""{"id": "myId"}""") val codec = KotlinSerializerCodec.create() codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) } + + var exception = + assertThrows("Invalid complex types - document") { + val data = BsonDocument.parse("""{"_id": "myId", "embedded": 123}""") + val codec = KotlinSerializerCodec.create() + codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) + } + assertEquals( + "Invalid data for `org.bson.codecs.kotlinx.samples.DataClassEmbedded` " + + "expected a bson document found: INT32", + exception.message) + + exception = + assertThrows("Invalid complex types - list") { + val data = BsonDocument.parse("""{"_id": "myId", "nested": 123}""") + val codec = KotlinSerializerCodec.create() + codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) + } + assertEquals("Invalid data for `LIST` expected a bson array found: INT32", exception.message) + + exception = + assertThrows("Invalid complex types - map") { + val data = BsonDocument.parse("""{"_id": "myId", "nested": 123}""") + val codec = KotlinSerializerCodec.create() + codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) + } + assertEquals("Invalid data for `MAP` expected a bson document found: INT32", exception.message) + + exception = + assertThrows("Missing discriminator") { + val data = BsonDocument.parse("""{"_id": {"$oid": "111111111111111111111111"}, "name": "string"}""") + val codec = KotlinSerializerCodec.create() + codec?.decode(BsonDocumentReader(data), DecoderContext.builder().build()) + } + assertEquals( + "Missing required discriminator field `_t` for polymorphic class: " + + "`org.bson.codecs.kotlinx.samples.SealedInterface`.", + exception.message) } @Test diff --git a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/samples/DataClasses.kt b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/samples/DataClasses.kt index ea5e3fea3cd..0326827d4a7 100644 --- a/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/samples/DataClasses.kt +++ b/bson-kotlinx/src/test/kotlin/org/bson/codecs/kotlinx/samples/DataClasses.kt @@ -245,6 +245,15 @@ data class DataClassOptionalBsonValues( @Serializable @SerialName("C") data class DataClassSealedC(val c: String) : DataClassSealed() +@Serializable +sealed interface SealedInterface { + val name: String +} + +@Serializable +data class DataClassSealedInterface(@Contextual @SerialName("_id") val id: ObjectId, override val name: String) : + SealedInterface + @Serializable data class DataClassListOfSealed(val items: List) interface DataClassOpen