Skip to content

Commit

Permalink
DGS-8901 Ensure logicalType flag passed to ReflectData
Browse files Browse the repository at this point in the history
  • Loading branch information
rayokota committed Jan 9, 2024
1 parent 10a588d commit 26a3423
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 8 deletions.
Expand Up @@ -16,6 +16,9 @@

package io.confluent.kafka.serializers;

import static io.confluent.kafka.schemaregistry.avro.AvroSchemaUtils.getReflectData;
import static io.confluent.kafka.schemaregistry.avro.AvroSchemaUtils.getReflectDataAllowNull;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
Expand Down Expand Up @@ -82,7 +85,7 @@ public DatumReader<?> load(IdentityPair<Schema, Schema> key) {
} else if (useSchemaReflection) {
return new ReflectDatumReader<>(writerSchema, finalReaderSchema,
avroUseLogicalTypeConverters
? AvroSchemaUtils.getReflectData()
? getReflectData()
: ReflectData.get());
} else if (useSpecificAvroReader) {
return new SpecificDatumReader<>(writerSchema, finalReaderSchema);
Expand Down Expand Up @@ -357,8 +360,9 @@ private Schema getSpecificReaderSchema(Schema writerSchema) {
}

private Schema getReflectionReaderSchema(Schema writerSchema) {
ReflectData reflectData = avroReflectionAllowNull ? ReflectData.AllowNull.get()
: ReflectData.get();
ReflectData reflectData = avroReflectionAllowNull
? (avroUseLogicalTypeConverters ? getReflectDataAllowNull() : ReflectData.AllowNull.get())
: (avroUseLogicalTypeConverters ? getReflectData() : ReflectData.get());
Class<?> readerClass = reflectData.getClass(writerSchema);
if (readerClass == null) {
throw new SerializationException("Could not find class "
Expand Down
Expand Up @@ -64,7 +64,7 @@ public byte[] serialize(String topic, Headers headers, Object record) {
}
AvroSchema schema = new AvroSchema(
AvroSchemaUtils.getSchema(record, useSchemaReflection,
avroReflectionAllowNull, removeJavaProperties));
avroReflectionAllowNull, avroUseLogicalTypeConverters, removeJavaProperties, true));
return serializeImpl(
getSubjectName(topic, isKey, record, schema), topic, headers, record, schema);
}
Expand Down
Expand Up @@ -30,6 +30,8 @@
import java.time.LocalTime;
import java.util.Arrays;

import java.util.Objects;
import java.util.UUID;
import org.apache.avro.*;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
Expand Down Expand Up @@ -112,6 +114,10 @@ public KafkaAvroSerializerTest() {
KafkaAvroDeserializerConfig.SCHEMA_REGISTRY_URL_CONFIG, "bogus");
reflectionDecoderProps.setProperty(
KafkaAvroDeserializerConfig.SCHEMA_REFLECTION_CONFIG, "true");
reflectionDecoderProps.setProperty(
KafkaAvroDeserializerConfig.AVRO_REFLECTION_ALLOW_NULL_CONFIG, "true");
reflectionDecoderProps.setProperty(
KafkaAvroDeserializerConfig.AVRO_USE_LOGICAL_TYPE_CONVERTERS_CONFIG, "true");
reflectionAvroDecoder = new KafkaAvroDecoder(
schemaRegistry, new VerifiableProperties(reflectionDecoderProps));
}
Expand Down Expand Up @@ -1022,6 +1028,73 @@ public void testKafkaAvroSerializerReflectionRecordWithNullField() {
assertEquals(widget, obj);
}

@Test
public void testKafkaAvroSerializerReflectionRecordWithLogicalType() {
byte[] bytes;
Object obj;

RecordWithUUID record = new RecordWithUUID();
record.uuid = UUID.randomUUID();

Schema schema = AvroSchemaUtils.getReflectData().getSchema(record.getClass());

Map configs = ImmutableMap.of(
KafkaAvroDeserializerConfig.SCHEMA_REGISTRY_URL_CONFIG, "bogus",
AbstractKafkaSchemaSerDeConfig.SCHEMA_REFLECTION_CONFIG, true,
KafkaAvroSerializerConfig.AVRO_USE_LOGICAL_TYPE_CONVERTERS_CONFIG, true
);
reflectionAvroDeserializer.configure(configs, false);
reflectionAvroSerializer.configure(configs, false);

bytes = reflectionAvroSerializer.serialize(topic, record);
obj = reflectionAvroDecoder.fromBytes(bytes, schema);
assertTrue(
"Returned object should be a RecordWithUUID",
RecordWithUUID.class.isInstance(obj)
);
assertEquals(record, obj);

obj = reflectionAvroDeserializer.deserialize(topic, bytes);
assertTrue(
"Returned object should be a RecordWithUUID",
RecordWithUUID.class.isInstance(obj)
);
assertEquals(record, obj);
}

@Test
public void testKafkaAvroSerializerReflectionRecordWithLogicalTypeNullField() {
byte[] bytes;
Object obj;

RecordWithUUID record = new RecordWithUUID();

Schema schema = AvroSchemaUtils.getReflectDataAllowNull().getSchema(record.getClass());

Map configs = ImmutableMap.of(
KafkaAvroDeserializerConfig.SCHEMA_REGISTRY_URL_CONFIG, "bogus",
AbstractKafkaSchemaSerDeConfig.SCHEMA_REFLECTION_CONFIG, true,
KafkaAvroSerializerConfig.AVRO_REFLECTION_ALLOW_NULL_CONFIG, true,
KafkaAvroSerializerConfig.AVRO_USE_LOGICAL_TYPE_CONVERTERS_CONFIG, true
);
reflectionAvroDeserializer.configure(configs, false);
reflectionAvroSerializer.configure(configs, false);

bytes = reflectionAvroSerializer.serialize(topic, record);
obj = reflectionAvroDecoder.fromBytes(bytes, schema);
assertTrue(
"Returned object should be a RecordWithUUID",
RecordWithUUID.class.isInstance(obj)
);
assertEquals(record, obj);

obj = reflectionAvroDeserializer.deserialize(topic, bytes);
assertTrue(
"Returned object should be a RecordWithUUID",
RecordWithUUID.class.isInstance(obj)
);
assertEquals(record, obj);
}

@Test
public void testKafkaAvroSerializerReflectionRecordWithProjection() {
Expand Down Expand Up @@ -1212,4 +1285,25 @@ public void testResolvedFormat() throws IOException, RestClientException {
+ "\"fields\":[{\"name\":\"accountNumber\",\"type\":\"string\"}]}]";
assertEquals(expectedResolved, schema.formattedString(Format.RESOLVED.symbol()));
}

static class RecordWithUUID {
UUID uuid;

@Override
public int hashCode() {
return uuid.hashCode();
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (!(obj instanceof RecordWithUUID)) {
return false;
}
RecordWithUUID that = (RecordWithUUID) obj;
return Objects.equals(this.uuid, that.uuid);
}
}
}
Expand Up @@ -55,6 +55,7 @@
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.io.JsonEncoder;
import org.apache.avro.reflect.ReflectData;
import org.apache.avro.reflect.ReflectData.AllowNull;
import org.apache.avro.reflect.ReflectDatumWriter;
import org.apache.avro.specific.SpecificDatumWriter;
import org.apache.avro.specific.SpecificRecord;
Expand All @@ -78,10 +79,12 @@ public class AvroSchemaUtils {

private static final GenericData GENERIC_DATA_INSTANCE = new GenericData();
private static final ReflectData REFLECT_DATA_INSTANCE = new ReflectData();
private static final ReflectData REFLECT_DATA_ALLOW_NULL_INSTANCE = new AllowNull();

static {
addLogicalTypeConversion(GENERIC_DATA_INSTANCE);
addLogicalTypeConversion(REFLECT_DATA_INSTANCE);
addLogicalTypeConversion(REFLECT_DATA_ALLOW_NULL_INSTANCE);
}

public static GenericData getGenericData() {
Expand All @@ -92,6 +95,10 @@ public static ReflectData getReflectData() {
return REFLECT_DATA_INSTANCE;
}

public static ReflectData getReflectDataAllowNull() {
return REFLECT_DATA_ALLOW_NULL_INSTANCE;
}

public static void addLogicalTypeConversion(GenericData avroData) {
avroData.addLogicalTypeConversion(new Conversions.DecimalConversion());
avroData.addLogicalTypeConversion(new Conversions.UUIDConversion());
Expand Down Expand Up @@ -169,6 +176,13 @@ public static Schema getSchema(Object object, boolean useReflection,
public static Schema getSchema(Object object, boolean useReflection,
boolean reflectionAllowNull, boolean removeJavaProperties,
boolean throwError) {
return getSchema(object, useReflection, reflectionAllowNull, false,
removeJavaProperties, throwError);
}

public static Schema getSchema(Object object, boolean useReflection,
boolean reflectionAllowNull, boolean useLogicalTypeConverters,
boolean removeJavaProperties, boolean throwError) {
if (object == null) {
return primitiveSchemas.get("Null");
} else if (object instanceof Boolean) {
Expand All @@ -186,8 +200,10 @@ public static Schema getSchema(Object object, boolean useReflection,
} else if (object instanceof byte[] || object instanceof ByteBuffer) {
return primitiveSchemas.get("Bytes");
} else if (useReflection) {
Schema schema = reflectionAllowNull ? ReflectData.AllowNull.get().getSchema(object.getClass())
: ReflectData.get().getSchema(object.getClass());
ReflectData reflectData = reflectionAllowNull
? (useLogicalTypeConverters ? getReflectDataAllowNull() : ReflectData.AllowNull.get())
: (useLogicalTypeConverters ? getReflectData() : ReflectData.get());
Schema schema = reflectData.getSchema(object.getClass());
if (schema == null) {
throw new SerializationException("Schema is null for object of class " + object.getClass()
.getCanonicalName());
Expand Down Expand Up @@ -221,8 +237,10 @@ public static Schema getSchema(Object object, boolean useReflection,

} else {
// Try reflection as last resort
Schema schema = reflectionAllowNull ? ReflectData.AllowNull.get().getSchema(object.getClass())
: ReflectData.get().getSchema(object.getClass());
ReflectData reflectData = reflectionAllowNull
? (useLogicalTypeConverters ? getReflectDataAllowNull() : ReflectData.AllowNull.get())
: (useLogicalTypeConverters ? getReflectData() : ReflectData.get());
Schema schema = reflectData.getSchema(object.getClass());
if (schema == null) {
throw new SerializationException("Schema is null for object of class " + object.getClass()
.getCanonicalName());
Expand Down

0 comments on commit 26a3423

Please sign in to comment.