diff --git a/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroDeserializer.java b/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroDeserializer.java index 90f9c9c3434..befc0884da5 100644 --- a/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroDeserializer.java +++ b/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroDeserializer.java @@ -40,11 +40,24 @@ public class ReflectionAvroDeserializer implements Deserializer { private final KafkaAvroDeserializer inner; private final Schema schema; + public ReflectionAvroDeserializer() { + this.schema = null; + this.inner = new KafkaAvroDeserializer(); + } + public ReflectionAvroDeserializer(Class type) { this.schema = ReflectData.get().getSchema(type); this.inner = new KafkaAvroDeserializer(); } + /** + * For testing purposes only. + */ + ReflectionAvroDeserializer(final SchemaRegistryClient client) { + this.schema = null; + this.inner = new KafkaAvroDeserializer(client); + } + /** * For testing purposes only. */ @@ -61,6 +74,7 @@ public void configure(final Map deserializerConfig, isDeserializerForRecordKeys); } + @SuppressWarnings("unchecked") @Override public T deserialize(final String topic, final byte[] bytes) { return (T) inner.deserialize(topic, bytes, schema); diff --git a/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerde.java b/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerde.java index 4dc1ac42ba6..9c5a9c03302 100644 --- a/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerde.java +++ b/avro-serde/src/main/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerde.java @@ -72,11 +72,28 @@ public class ReflectionAvroSerde implements Serde { private final Serde inner; + public ReflectionAvroSerde() { + inner = Serdes + .serdeFrom(new ReflectionAvroSerializer<>(), new ReflectionAvroDeserializer<>()); + } + public ReflectionAvroSerde(Class type) { inner = Serdes .serdeFrom(new ReflectionAvroSerializer<>(), new ReflectionAvroDeserializer<>(type)); } + /** + * For testing purposes only. + */ + public ReflectionAvroSerde(final SchemaRegistryClient client) { + if (client == null) { + throw new IllegalArgumentException("schema registry client must not be null"); + } + inner = Serdes.serdeFrom( + new ReflectionAvroSerializer<>(client), + new ReflectionAvroDeserializer<>(client)); + } + /** * For testing purposes only. */ diff --git a/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeTest.java b/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeGenericTest.java similarity index 85% rename from avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeTest.java rename to avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeGenericTest.java index 34358329ab0..0fd0499f184 100644 --- a/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeTest.java +++ b/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeGenericTest.java @@ -28,14 +28,13 @@ import java.util.Map; import org.junit.Test; -public class ReflectionAvroSerdeTest { +public class ReflectionAvroSerdeGenericTest { private static final String ANY_TOPIC = "any-topic"; - private static ReflectionAvroSerde - createConfiguredSerdeForRecordValues(Class type) { + private static ReflectionAvroSerde createConfiguredSerde() { SchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient(); - ReflectionAvroSerde serde = new ReflectionAvroSerde<>(schemaRegistryClient, type); + ReflectionAvroSerde serde = new ReflectionAvroSerde<>(schemaRegistryClient); Map serdeConfig = new HashMap<>(); serdeConfig.put(AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG, "fake"); serde.configure(serdeConfig, false); @@ -45,7 +44,7 @@ public class ReflectionAvroSerdeTest { @Test public void shouldRoundTripRecords() { // Given - ReflectionAvroSerde serde = createConfiguredSerdeForRecordValues(Widget.class); + ReflectionAvroSerde serde = createConfiguredSerde(); Widget record = new Widget("alice"); // When @@ -63,7 +62,7 @@ public void shouldRoundTripRecords() { @Test public void shouldRoundTripNullRecordsToNull() { // Given - ReflectionAvroSerde serde = createConfiguredSerdeForRecordValues(Widget.class); + ReflectionAvroSerde serde = createConfiguredSerde(); // When Widget roundtrippedRecord = serde.deserializer().deserialize( @@ -79,7 +78,7 @@ public void shouldRoundTripNullRecordsToNull() { @Test(expected = IllegalArgumentException.class) public void shouldFailWhenInstantiatedWithNullSchemaRegistryClient() { - new ReflectionAvroSerde<>(null, Widget.class); + new ReflectionAvroSerde<>((SchemaRegistryClient)null); } } \ No newline at end of file diff --git a/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeSpecificTest.java b/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeSpecificTest.java new file mode 100644 index 00000000000..0c43b0cfd8b --- /dev/null +++ b/avro-serde/src/test/java/io/confluent/kafka/streams/serdes/avro/ReflectionAvroSerdeSpecificTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2018 Confluent Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.confluent.kafka.streams.serdes.avro; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.junit.Assert.assertThat; + +import io.confluent.kafka.example.Widget; +import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient; +import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; +import io.confluent.kafka.serializers.AbstractKafkaAvroSerDeConfig; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; + +public class ReflectionAvroSerdeSpecificTest { + + private static final String ANY_TOPIC = "any-topic"; + + private static ReflectionAvroSerde + createConfiguredSerdeForRecordValues(Class type) { + SchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient(); + ReflectionAvroSerde serde = new ReflectionAvroSerde<>(schemaRegistryClient, type); + Map serdeConfig = new HashMap<>(); + serdeConfig.put(AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG, "fake"); + serde.configure(serdeConfig, false); + return serde; + } + + private static ReflectionAvroSerde + createConfiguredSerdeForAnyValues() { + SchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient(); + ReflectionAvroSerde serde = new ReflectionAvroSerde<>(schemaRegistryClient); + Map serdeConfig = new HashMap<>(); + serdeConfig.put(AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG, "fake"); + serde.configure(serdeConfig, false); + return serde; + } + + @Test + public void shouldRoundTripRecords() { + // Given + ReflectionAvroSerde serde = createConfiguredSerdeForRecordValues(Widget.class); + Widget record = new Widget("alice"); + + // When + Widget roundtrippedRecord = serde.deserializer().deserialize( + ANY_TOPIC, + serde.serializer().serialize(ANY_TOPIC, record)); + + // Then + assertThat(roundtrippedRecord, equalTo(record)); + + // Cleanup + serde.close(); + } + + @Test + public void shouldRoundTripNullRecordsToNull() { + // Given + ReflectionAvroSerde serde = createConfiguredSerdeForRecordValues(Widget.class); + + // When + Widget roundtrippedRecord = serde.deserializer().deserialize( + ANY_TOPIC, + serde.serializer().serialize(ANY_TOPIC, null)); + + // Then + assertThat(roundtrippedRecord, nullValue()); + + // Cleanup + serde.close(); + } + + @Test(expected = IllegalArgumentException.class) + public void shouldFailWhenInstantiatedWithNullSchemaRegistryClient() { + new ReflectionAvroSerde<>(null, Widget.class); + } + +} \ No newline at end of file diff --git a/avro-serializer/src/main/java/io/confluent/kafka/serializers/AbstractKafkaAvroDeserializer.java b/avro-serializer/src/main/java/io/confluent/kafka/serializers/AbstractKafkaAvroDeserializer.java index cec908cf872..be7e23b67e3 100644 --- a/avro-serializer/src/main/java/io/confluent/kafka/serializers/AbstractKafkaAvroDeserializer.java +++ b/avro-serializer/src/main/java/io/confluent/kafka/serializers/AbstractKafkaAvroDeserializer.java @@ -26,7 +26,7 @@ import org.apache.avro.specific.SpecificDatumReader; import org.apache.avro.specific.SpecificRecord; import org.apache.kafka.common.errors.SerializationException; - +import org.apache.avro.reflect.ReflectData; import org.apache.avro.reflect.ReflectDatumReader; import java.io.IOException; @@ -177,58 +177,92 @@ protected GenericContainerWithVersion deserializeWithSchemaAndVersion( } protected DatumReader getDatumReader(Schema writerSchema, Schema readerSchema) { + // normalize reader schema + readerSchema = getReaderSchema(writerSchema, readerSchema); boolean writerSchemaIsPrimitive = AvroSchemaUtils.getPrimitiveSchemas().values().contains(writerSchema); - // do not use SpecificDatumReader if writerSchema is a primitive - if (useSchemaReflection && !writerSchemaIsPrimitive) { - if (readerSchema == null) { - throw new SerializationException( - "Reader schema cannot be null when using Avro schema reflection"); - } + if (writerSchemaIsPrimitive) { + return new GenericDatumReader<>(writerSchema, readerSchema); + } else if (useSchemaReflection) { return new ReflectDatumReader<>(writerSchema, readerSchema); - } else if (useSpecificAvroReader && !writerSchemaIsPrimitive) { - if (readerSchema == null) { - readerSchema = getReaderSchema(writerSchema); - } + } else if (useSpecificAvroReader) { return new SpecificDatumReader<>(writerSchema, readerSchema); } else { - if (readerSchema == null) { - return new GenericDatumReader<>(writerSchema); - } return new GenericDatumReader<>(writerSchema, readerSchema); } } - @SuppressWarnings("unchecked") - private Schema getReaderSchema(Schema writerSchema) { - Schema readerSchema = readerSchemaCache.get(writerSchema.getFullName()); - if (readerSchema == null) { - Class readerClass = SpecificData.get().getClass(writerSchema); - if (readerClass != null) { - try { - readerSchema = readerClass.newInstance().getSchema(); - } catch (InstantiationException e) { - throw new SerializationException(writerSchema.getFullName() - + " specified by the " - + "writers schema could not be instantiated to " - + "find the readers schema."); - } catch (IllegalAccessException e) { - throw new SerializationException(writerSchema.getFullName() - + " specified by the " - + "writers schema is not allowed to be instantiated " - + "to find the readers schema."); - } - readerSchemaCache.put(writerSchema.getFullName(), readerSchema); - } else { - throw new SerializationException("Could not find class " - + writerSchema.getFullName() - + " specified in writer's schema whilst finding reader's " - + "schema for a SpecificRecord."); - } + /** + * Normalizes the reader schema, puts the resolved schema into the cache. + *
  • + *
      if the reader schema is provided, use the provided one
    + *
      if the reader schema is cached for the writer schema full name, use the cached value
    + *
      if the writer schema is primitive, use the writer one
    + *
      if schema reflection is used, generate one from the class referred by writer schema
    + *
      if generated classes are used, query the class referred by writer schema
    + *
      otherwise use the writer schema
    + *
  • + */ + private Schema getReaderSchema(Schema writerSchema, Schema readerSchema) { + if (readerSchema != null) { + return readerSchema; + } + readerSchema = readerSchemaCache.get(writerSchema.getFullName()); + if (readerSchema != null) { + return readerSchema; + } + boolean writerSchemaIsPrimitive = + AvroSchemaUtils.getPrimitiveSchemas().values().contains(writerSchema); + if (writerSchemaIsPrimitive) { + readerSchema = writerSchema; + } else if (useSchemaReflection) { + readerSchema = getReflectionReaderSchema(writerSchema); + readerSchemaCache.put(writerSchema.getFullName(), readerSchema); + } else if (useSpecificAvroReader) { + readerSchema = getSpecificReaderSchema(writerSchema); + readerSchemaCache.put(writerSchema.getFullName(), readerSchema); + } else { + readerSchema = writerSchema; } return readerSchema; } + @SuppressWarnings("unchecked") + private Schema getSpecificReaderSchema(Schema writerSchema) { + Class readerClass = SpecificData.get().getClass(writerSchema); + if (readerClass == null) { + throw new SerializationException("Could not find class " + + writerSchema.getFullName() + + " specified in writer's schema whilst finding reader's " + + "schema for a SpecificRecord."); + } + try { + return readerClass.newInstance().getSchema(); + } catch (InstantiationException e) { + throw new SerializationException(writerSchema.getFullName() + + " specified by the " + + "writers schema could not be instantiated to " + + "find the readers schema."); + } catch (IllegalAccessException e) { + throw new SerializationException(writerSchema.getFullName() + + " specified by the " + + "writers schema is not allowed to be instantiated " + + "to find the readers schema."); + } + } + + private Schema getReflectionReaderSchema(Schema writerSchema) { + // shall we use ReflectData.AllowNull.get() instead? + Class readerClass = ReflectData.get().getClass(writerSchema); + if (readerClass == null) { + throw new SerializationException("Could not find class " + + writerSchema.getFullName() + + " specified in writer's schema whilst finding reader's " + + "schema for a reflected class."); + } + return ReflectData.get().getSchema(readerClass); + } + class DeserializationContext { private final String topic; private final Boolean isKey;