Skip to content

Commit

Permalink
Reflection serde (#1281)
Browse files Browse the repository at this point in the history
* enabled default constructor

Default constructor allows the SerDe to be used as default serde in KStreams config.

* added logic to resolve the reflection schema when it is not provided during construction

* fixed whitespace warning

* added missing constructors

* cloned the original test case to support dynamic type recognition

* cache only generated schemas

The failure in testVersionMaintained was caused by caching the schema for given name. In case
of GenericDatumReader the schema is as it comes from the serialized record. This means
if the schema evolves, there could be multiple representations of the same avro record.
  • Loading branch information
piotrsmolinski authored and rayokota committed Jan 13, 2020
1 parent ea2fcfb commit f8ca0ce
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 47 deletions.
Expand Up @@ -40,11 +40,24 @@ public class ReflectionAvroDeserializer<T> implements Deserializer<T> {
private final KafkaAvroDeserializer inner;
private final Schema schema;

public ReflectionAvroDeserializer() {
this.schema = null;
this.inner = new KafkaAvroDeserializer();
}

public ReflectionAvroDeserializer(Class<T> type) {
this.schema = ReflectData.get().getSchema(type);
this.inner = new KafkaAvroDeserializer();
}

/**
* For testing purposes only.
*/
ReflectionAvroDeserializer(final SchemaRegistryClient client) {
this.schema = null;
this.inner = new KafkaAvroDeserializer(client);
}

/**
* For testing purposes only.
*/
Expand All @@ -61,6 +74,7 @@ public void configure(final Map<String, ?> deserializerConfig,
isDeserializerForRecordKeys);
}

@SuppressWarnings("unchecked")
@Override
public T deserialize(final String topic, final byte[] bytes) {
return (T) inner.deserialize(topic, bytes, schema);
Expand Down
Expand Up @@ -72,11 +72,28 @@ public class ReflectionAvroSerde<T> implements Serde<T> {

private final Serde<T> inner;

public ReflectionAvroSerde() {
inner = Serdes
.serdeFrom(new ReflectionAvroSerializer<>(), new ReflectionAvroDeserializer<>());
}

public ReflectionAvroSerde(Class<T> type) {
inner = Serdes
.serdeFrom(new ReflectionAvroSerializer<>(), new ReflectionAvroDeserializer<>(type));
}

/**
* For testing purposes only.
*/
public ReflectionAvroSerde(final SchemaRegistryClient client) {
if (client == null) {
throw new IllegalArgumentException("schema registry client must not be null");
}
inner = Serdes.serdeFrom(
new ReflectionAvroSerializer<>(client),
new ReflectionAvroDeserializer<>(client));
}

/**
* For testing purposes only.
*/
Expand Down
Expand Up @@ -28,14 +28,13 @@
import java.util.Map;
import org.junit.Test;

public class ReflectionAvroSerdeTest {
public class ReflectionAvroSerdeGenericTest {

private static final String ANY_TOPIC = "any-topic";

private static <T> ReflectionAvroSerde<T>
createConfiguredSerdeForRecordValues(Class<T> type) {
private static <T> ReflectionAvroSerde<T> createConfiguredSerde() {
SchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient();
ReflectionAvroSerde<T> serde = new ReflectionAvroSerde<>(schemaRegistryClient, type);
ReflectionAvroSerde<T> serde = new ReflectionAvroSerde<>(schemaRegistryClient);
Map<String, Object> serdeConfig = new HashMap<>();
serdeConfig.put(AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG, "fake");
serde.configure(serdeConfig, false);
Expand All @@ -45,7 +44,7 @@ public class ReflectionAvroSerdeTest {
@Test
public void shouldRoundTripRecords() {
// Given
ReflectionAvroSerde<Widget> serde = createConfiguredSerdeForRecordValues(Widget.class);
ReflectionAvroSerde<Widget> serde = createConfiguredSerde();
Widget record = new Widget("alice");

// When
Expand All @@ -63,7 +62,7 @@ public void shouldRoundTripRecords() {
@Test
public void shouldRoundTripNullRecordsToNull() {
// Given
ReflectionAvroSerde<Widget> serde = createConfiguredSerdeForRecordValues(Widget.class);
ReflectionAvroSerde<Widget> serde = createConfiguredSerde();

// When
Widget roundtrippedRecord = serde.deserializer().deserialize(
Expand All @@ -79,7 +78,7 @@ public void shouldRoundTripNullRecordsToNull() {

@Test(expected = IllegalArgumentException.class)
public void shouldFailWhenInstantiatedWithNullSchemaRegistryClient() {
new ReflectionAvroSerde<>(null, Widget.class);
new ReflectionAvroSerde<>((SchemaRegistryClient)null);
}

}
@@ -0,0 +1,95 @@
/*
* Copyright 2018 Confluent Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.confluent.kafka.streams.serdes.avro;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertThat;

import io.confluent.kafka.example.Widget;
import io.confluent.kafka.schemaregistry.client.MockSchemaRegistryClient;
import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient;
import io.confluent.kafka.serializers.AbstractKafkaAvroSerDeConfig;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;

public class ReflectionAvroSerdeSpecificTest {

private static final String ANY_TOPIC = "any-topic";

private static <T> ReflectionAvroSerde<T>
createConfiguredSerdeForRecordValues(Class<T> type) {
SchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient();
ReflectionAvroSerde<T> serde = new ReflectionAvroSerde<>(schemaRegistryClient, type);
Map<String, Object> serdeConfig = new HashMap<>();
serdeConfig.put(AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG, "fake");
serde.configure(serdeConfig, false);
return serde;
}

private static <T> ReflectionAvroSerde<T>
createConfiguredSerdeForAnyValues() {
SchemaRegistryClient schemaRegistryClient = new MockSchemaRegistryClient();
ReflectionAvroSerde<T> serde = new ReflectionAvroSerde<>(schemaRegistryClient);
Map<String, Object> serdeConfig = new HashMap<>();
serdeConfig.put(AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG, "fake");
serde.configure(serdeConfig, false);
return serde;
}

@Test
public void shouldRoundTripRecords() {
// Given
ReflectionAvroSerde<Widget> serde = createConfiguredSerdeForRecordValues(Widget.class);
Widget record = new Widget("alice");

// When
Widget roundtrippedRecord = serde.deserializer().deserialize(
ANY_TOPIC,
serde.serializer().serialize(ANY_TOPIC, record));

// Then
assertThat(roundtrippedRecord, equalTo(record));

// Cleanup
serde.close();
}

@Test
public void shouldRoundTripNullRecordsToNull() {
// Given
ReflectionAvroSerde<Widget> serde = createConfiguredSerdeForRecordValues(Widget.class);

// When
Widget roundtrippedRecord = serde.deserializer().deserialize(
ANY_TOPIC,
serde.serializer().serialize(ANY_TOPIC, null));

// Then
assertThat(roundtrippedRecord, nullValue());

// Cleanup
serde.close();
}

@Test(expected = IllegalArgumentException.class)
public void shouldFailWhenInstantiatedWithNullSchemaRegistryClient() {
new ReflectionAvroSerde<>(null, Widget.class);
}

}
Expand Up @@ -26,7 +26,7 @@
import org.apache.avro.specific.SpecificDatumReader;
import org.apache.avro.specific.SpecificRecord;
import org.apache.kafka.common.errors.SerializationException;

import org.apache.avro.reflect.ReflectData;
import org.apache.avro.reflect.ReflectDatumReader;

import java.io.IOException;
Expand Down Expand Up @@ -177,58 +177,92 @@ protected GenericContainerWithVersion deserializeWithSchemaAndVersion(
}

protected DatumReader<?> getDatumReader(Schema writerSchema, Schema readerSchema) {
// normalize reader schema
readerSchema = getReaderSchema(writerSchema, readerSchema);
boolean writerSchemaIsPrimitive =
AvroSchemaUtils.getPrimitiveSchemas().values().contains(writerSchema);
// do not use SpecificDatumReader if writerSchema is a primitive
if (useSchemaReflection && !writerSchemaIsPrimitive) {
if (readerSchema == null) {
throw new SerializationException(
"Reader schema cannot be null when using Avro schema reflection");
}
if (writerSchemaIsPrimitive) {
return new GenericDatumReader<>(writerSchema, readerSchema);
} else if (useSchemaReflection) {
return new ReflectDatumReader<>(writerSchema, readerSchema);
} else if (useSpecificAvroReader && !writerSchemaIsPrimitive) {
if (readerSchema == null) {
readerSchema = getReaderSchema(writerSchema);
}
} else if (useSpecificAvroReader) {
return new SpecificDatumReader<>(writerSchema, readerSchema);
} else {
if (readerSchema == null) {
return new GenericDatumReader<>(writerSchema);
}
return new GenericDatumReader<>(writerSchema, readerSchema);
}
}

@SuppressWarnings("unchecked")
private Schema getReaderSchema(Schema writerSchema) {
Schema readerSchema = readerSchemaCache.get(writerSchema.getFullName());
if (readerSchema == null) {
Class<SpecificRecord> readerClass = SpecificData.get().getClass(writerSchema);
if (readerClass != null) {
try {
readerSchema = readerClass.newInstance().getSchema();
} catch (InstantiationException e) {
throw new SerializationException(writerSchema.getFullName()
+ " specified by the "
+ "writers schema could not be instantiated to "
+ "find the readers schema.");
} catch (IllegalAccessException e) {
throw new SerializationException(writerSchema.getFullName()
+ " specified by the "
+ "writers schema is not allowed to be instantiated "
+ "to find the readers schema.");
}
readerSchemaCache.put(writerSchema.getFullName(), readerSchema);
} else {
throw new SerializationException("Could not find class "
+ writerSchema.getFullName()
+ " specified in writer's schema whilst finding reader's "
+ "schema for a SpecificRecord.");
}
/**
* Normalizes the reader schema, puts the resolved schema into the cache.
* <li>
* <ul>if the reader schema is provided, use the provided one</ul>
* <ul>if the reader schema is cached for the writer schema full name, use the cached value</ul>
* <ul>if the writer schema is primitive, use the writer one</ul>
* <ul>if schema reflection is used, generate one from the class referred by writer schema</ul>
* <ul>if generated classes are used, query the class referred by writer schema</ul>
* <ul>otherwise use the writer schema</ul>
* </li>
*/
private Schema getReaderSchema(Schema writerSchema, Schema readerSchema) {
if (readerSchema != null) {
return readerSchema;
}
readerSchema = readerSchemaCache.get(writerSchema.getFullName());
if (readerSchema != null) {
return readerSchema;
}
boolean writerSchemaIsPrimitive =
AvroSchemaUtils.getPrimitiveSchemas().values().contains(writerSchema);
if (writerSchemaIsPrimitive) {
readerSchema = writerSchema;
} else if (useSchemaReflection) {
readerSchema = getReflectionReaderSchema(writerSchema);
readerSchemaCache.put(writerSchema.getFullName(), readerSchema);
} else if (useSpecificAvroReader) {
readerSchema = getSpecificReaderSchema(writerSchema);
readerSchemaCache.put(writerSchema.getFullName(), readerSchema);
} else {
readerSchema = writerSchema;
}
return readerSchema;
}

@SuppressWarnings("unchecked")
private Schema getSpecificReaderSchema(Schema writerSchema) {
Class<SpecificRecord> readerClass = SpecificData.get().getClass(writerSchema);
if (readerClass == null) {
throw new SerializationException("Could not find class "
+ writerSchema.getFullName()
+ " specified in writer's schema whilst finding reader's "
+ "schema for a SpecificRecord.");
}
try {
return readerClass.newInstance().getSchema();
} catch (InstantiationException e) {
throw new SerializationException(writerSchema.getFullName()
+ " specified by the "
+ "writers schema could not be instantiated to "
+ "find the readers schema.");
} catch (IllegalAccessException e) {
throw new SerializationException(writerSchema.getFullName()
+ " specified by the "
+ "writers schema is not allowed to be instantiated "
+ "to find the readers schema.");
}
}

private Schema getReflectionReaderSchema(Schema writerSchema) {
// shall we use ReflectData.AllowNull.get() instead?
Class<?> readerClass = ReflectData.get().getClass(writerSchema);
if (readerClass == null) {
throw new SerializationException("Could not find class "
+ writerSchema.getFullName()
+ " specified in writer's schema whilst finding reader's "
+ "schema for a reflected class.");
}
return ReflectData.get().getSchema(readerClass);
}

class DeserializationContext {
private final String topic;
private final Boolean isKey;
Expand Down

0 comments on commit f8ca0ce

Please sign in to comment.