fix: Refactor ConnectFormatSchemaTranslator to take translator object instead of lambda function (#8177)

so that the translator can take in configuration to decide how to transform field names, e.g. whether they should be uppercased, lowercased, or left as-is.
lihaosky committed Oct 1, 2021
1 parent 2274da8 commit 6d3b351
Showing 8 changed files with 64 additions and 63 deletions.
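To see why an object helps where a lambda did not, here is a minimal sketch, assuming the Connect API and the ksql serde class introduced in this diff are on the classpath. The configured-constructor idea in the comments is hypothetical and not part of this commit; only the no-arg constructor and the toKsqlSchema entry point match the code being merged.

import io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;

public final class TranslatorRefactorSketch {

  public static void main(String[] args) {
    // A Connect schema such as a schema-registry-backed format might hand over.
    final Schema connectSchema = SchemaBuilder.struct()
        .field("someField", Schema.OPTIONAL_STRING_SCHEMA)
        .build();

    // Before this commit the conversion step was injected as a bare lambda:
    //   new ConnectFormatSchemaTranslator(format, props, ConnectSchemaUtil::toKsqlSchema)
    // which leaves nowhere to attach configuration such as field-name casing.

    // After this commit an object is injected instead:
    //   new ConnectFormatSchemaTranslator(format, props, new ConnectKsqlSchemaTranslator())
    // so a follow-up change could construct it with a casing option,
    // e.g. new ConnectKsqlSchemaTranslator(caseConfig) -- hypothetical, not in this commit.

    // The entry point that call sites now use (matches the diff below):
    final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
    System.out.println(ksqlSchema.fields());
  }
}

Nothing about the translation output changes in this commit; only the seam through which the translator is supplied does.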
[changed file 1 of 8]
@@ -15,9 +15,9 @@

package io.confluent.ksql.function.udaf.offset;

-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_DATE_SCHEMA;
-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_TIMESTAMP_SCHEMA;
-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_TIME_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_DATE_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_TIMESTAMP_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_TIME_SCHEMA;

import java.util.Comparator;
import org.apache.kafka.connect.data.Schema;
[changed file 2 of 8]
@@ -15,9 +15,9 @@

package io.confluent.ksql.schema.ksql.inference;

-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_DATE_SCHEMA;
-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_TIMESTAMP_SCHEMA;
-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_TIME_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_DATE_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_TIMESTAMP_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_TIME_SCHEMA;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
[changed file 3 of 8]
@@ -57,7 +57,7 @@ public SchemaTranslator getSchemaTranslator(final Map<String, String> formatProp
return new ConnectFormatSchemaTranslator(
this,
formatProperties,
-ConnectSchemaUtil::toKsqlSchema
+new ConnectKsqlSchemaTranslator()
);
}

[changed file 4 of 8]
@@ -32,7 +32,6 @@
import io.confluent.ksql.util.KsqlException;
import java.util.List;
import java.util.Map;
-import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.connect.data.ConnectSchema;
import org.apache.kafka.connect.data.Field;
@@ -43,16 +42,16 @@ class ConnectFormatSchemaTranslator implements SchemaTranslator {

private final ConnectFormat format;
private final ConnectSchemaTranslator connectSrTranslator;
-private final Function<Schema, Schema> connectKsqlTranslator;
+private final ConnectKsqlSchemaTranslator connectKsqlTranslator;

ConnectFormatSchemaTranslator(
final ConnectFormat format,
final Map<String, String> formatProps,
-final Function<Schema, Schema> connectKsqlTranslator
+final ConnectKsqlSchemaTranslator connectKsqlSchemaTranslator
) {
this.format = requireNonNull(format, "format");
this.connectSrTranslator = requireNonNull(format.getConnectSchemaTranslator(formatProps));
-this.connectKsqlTranslator = requireNonNull(connectKsqlTranslator, "toKsqlTransformer");
+this.connectKsqlTranslator = requireNonNull(connectKsqlSchemaTranslator);
}

@Override
@@ -83,7 +82,7 @@ public List<SimpleColumn> toColumns(
+ "=false' in the WITH clause properties.");
}

-final Schema rowSchema = connectKsqlTranslator.apply(connectSchema);
+final Schema rowSchema = connectKsqlTranslator.toKsqlSchema(connectSchema);

return rowSchema.fields().stream()
.map(ConnectFormatSchemaTranslator::toColumn)
[changed file 5 of 8]
@@ -21,7 +21,7 @@
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.util.Map;
-import java.util.function.Function;
+import java.util.function.BiFunction;
import org.apache.kafka.connect.data.Date;
import org.apache.kafka.connect.data.Field;
import org.apache.kafka.connect.data.Schema;
@@ -32,42 +32,42 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

-public final class ConnectSchemaUtil {
+public final class ConnectKsqlSchemaTranslator {

private static final Logger log = LoggerFactory.getLogger(ConnectSchemaTranslator.class);

private static final SqlSchemaFormatter FORMATTER =
new SqlSchemaFormatter(w -> false, Option.AS_COLUMN_LIST);

-private static final Map<Type, Function<Schema, Schema>> CONNECT_TO_KSQL =
-ImmutableMap.<Type, Function<Schema, Schema>>builder()
-.put(Type.INT8, s -> Schema.OPTIONAL_INT32_SCHEMA)
-.put(Type.INT16, s -> Schema.OPTIONAL_INT32_SCHEMA)
-.put(Type.INT32, ConnectSchemaUtil::convertInt32Schema)
-.put(Type.INT64, ConnectSchemaUtil::convertInt64Schema)
-.put(Type.FLOAT32, s -> Schema.OPTIONAL_FLOAT64_SCHEMA)
-.put(Type.FLOAT64, s -> Schema.OPTIONAL_FLOAT64_SCHEMA)
-.put(Type.STRING, s -> Schema.OPTIONAL_STRING_SCHEMA)
-.put(Type.BOOLEAN, s -> Schema.OPTIONAL_BOOLEAN_SCHEMA)
-.put(Type.BYTES, ConnectSchemaUtil::toKsqlBytesSchema)
-.put(Type.ARRAY, ConnectSchemaUtil::toKsqlArraySchema)
-.put(Type.MAP, ConnectSchemaUtil::toKsqlMapSchema)
-.put(Type.STRUCT, ConnectSchemaUtil::toKsqlStructSchema)
-.build();
+private final Map<Type, BiFunction<ConnectKsqlSchemaTranslator, Schema, Schema>> connectToSql =
+ImmutableMap.<Type, BiFunction<ConnectKsqlSchemaTranslator, Schema, Schema>>builder()
+.put(Type.INT8, (instance, s) -> Schema.OPTIONAL_INT32_SCHEMA)
+.put(Type.INT16, (instance, s) -> Schema.OPTIONAL_INT32_SCHEMA)
+.put(Type.INT32, (instance, s) -> this.convertInt32Schema(s))
+.put(Type.INT64, (instance, s) -> instance.convertInt64Schema(s))
+.put(Type.FLOAT32, (instance, s) -> Schema.OPTIONAL_FLOAT64_SCHEMA)
+.put(Type.FLOAT64, (instance, s) -> Schema.OPTIONAL_FLOAT64_SCHEMA)
+.put(Type.STRING, (instance, s) -> Schema.OPTIONAL_STRING_SCHEMA)
+.put(Type.BOOLEAN, (instance, s) -> Schema.OPTIONAL_BOOLEAN_SCHEMA)
+.put(Type.BYTES, (instance, s) -> instance.toKsqlBytesSchema(s))
+.put(Type.ARRAY, (instance, s) -> instance.toKsqlArraySchema(s))
+.put(Type.MAP, (instance, s) -> instance.toKsqlMapSchema(s))
+.put(Type.STRUCT, (instance, s) -> instance.toKsqlStructSchema(s))
+.build();

public static final Schema OPTIONAL_TIMESTAMP_SCHEMA = Timestamp.builder().optional().build();
public static final Schema OPTIONAL_TIME_SCHEMA = Time.builder().optional().build();
public static final Schema OPTIONAL_DATE_SCHEMA = Date.builder().optional().build();

-private ConnectSchemaUtil() {
-}
+public ConnectKsqlSchemaTranslator() {}

/**
* Ensures all schema types are optional.
*
* @param schema the source schema.
* @return the ksql compatible schema.
*/
-public static Schema toKsqlSchema(final Schema schema) {
+public Schema toKsqlSchema(final Schema schema) {
try {
final Schema rowSchema = toKsqlFieldSchema(schema);
if (rowSchema.type() != Schema.Type.STRUCT) {
@@ -87,17 +87,18 @@ public static Schema toKsqlSchema(final Schema schema) {
}
}

-private static Schema toKsqlFieldSchema(final Schema schema) {
-final Function<Schema, Schema> handler = CONNECT_TO_KSQL.get(schema.type());
+private Schema toKsqlFieldSchema(final Schema schema) {
+final BiFunction<ConnectKsqlSchemaTranslator, Schema, Schema> handler = connectToSql.get(
+schema.type());
if (handler == null) {
throw new UnsupportedTypeException(
String.format("Unsupported type: %s", schema.type().getName()));
}

-return handler.apply(schema);
+return handler.apply(this, schema);
}

-private static void checkMapKeyType(final Schema.Type type) {
+private void checkMapKeyType(final Schema.Type type) {
switch (type) {
case INT8:
case INT16:
@@ -111,15 +112,15 @@ private static void checkMapKeyType(final Schema.Type type) {
}
}

-private static Schema convertInt64Schema(final Schema schema) {
+private Schema convertInt64Schema(final Schema schema) {
if (schema.name() == Timestamp.LOGICAL_NAME) {
return OPTIONAL_TIMESTAMP_SCHEMA;
} else {
return Schema.OPTIONAL_INT64_SCHEMA;
}
}

-private static Schema convertInt32Schema(final Schema schema) {
+private Schema convertInt32Schema(final Schema schema) {
if (schema.name() == Time.LOGICAL_NAME) {
return OPTIONAL_TIME_SCHEMA;
} else if (schema.name() == Date.LOGICAL_NAME) {
@@ -129,15 +130,15 @@ private static Schema convertInt32Schema(final Schema schema) {
}
}

-private static Schema toKsqlBytesSchema(final Schema schema) {
+private Schema toKsqlBytesSchema(final Schema schema) {
if (DecimalUtil.isDecimal(schema)) {
return schema;
} else {
return Schema.OPTIONAL_BYTES_SCHEMA;
}
}

-private static Schema toKsqlMapSchema(final Schema schema) {
+private Schema toKsqlMapSchema(final Schema schema) {
final Schema keySchema = toKsqlFieldSchema(schema.keySchema());
checkMapKeyType(keySchema.type());
return SchemaBuilder.map(
@@ -146,13 +147,13 @@ private static Schema toKsqlMapSchema(final Schema schema) {
).optional().build();
}

-private static Schema toKsqlArraySchema(final Schema schema) {
+private Schema toKsqlArraySchema(final Schema schema) {
return SchemaBuilder.array(
toKsqlFieldSchema(schema.valueSchema())
).optional().build();
}

-private static Schema toKsqlStructSchema(final Schema schema) {
+private Schema toKsqlStructSchema(final Schema schema) {
final SchemaBuilder schemaBuilder = SchemaBuilder.struct();
for (final Field field : schema.fields()) {
try {
@@ -166,6 +167,7 @@ private static Schema toKsqlStructSchema(final Schema schema) {
}

private static class UnsupportedTypeException extends RuntimeException {

UnsupportedTypeException(final String error) {
super(error);
}
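Stepping back from the line-by-line changes above: the old static CONNECT_TO_KSQL map held Function<Schema, Schema> handlers, which could only reach static helpers, whereas the new instance-level connectToSql map holds BiFunction<ConnectKsqlSchemaTranslator, Schema, Schema> handlers that receive the instance and dispatch back onto its now non-static methods, which is what leaves room for per-instance configuration later. A stripped-down, self-contained sketch of that dispatch pattern, with invented names (only the shape mirrors the real class):

import java.util.Map;
import java.util.function.BiFunction;

public class DispatchSketch {

  enum Kind { UPPER, LOWER }

  // Instance-level handler map: each handler is given the instance, so it can
  // call instance methods (which could in turn read instance configuration).
  private final Map<Kind, BiFunction<DispatchSketch, String, String>> handlers =
      Map.of(
          Kind.UPPER, (instance, s) -> instance.toUpper(s),
          Kind.LOWER, (instance, s) -> instance.toLower(s));

  private String toUpper(final String s) {
    return s.toUpperCase();
  }

  private String toLower(final String s) {
    return s.toLowerCase();
  }

  public String translate(final Kind kind, final String value) {
    final BiFunction<DispatchSketch, String, String> handler = handlers.get(kind);
    if (handler == null) {
      throw new IllegalArgumentException("Unsupported kind: " + kind);
    }
    // Mirrors handler.apply(this, schema) in toKsqlFieldSchema above.
    return handler.apply(this, value);
  }

  public static void main(String[] args) {
    System.out.println(new DispatchSketch().translate(Kind.UPPER, "fieldName"));
  }
}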
[changed file 6 of 8]
@@ -54,7 +54,7 @@
public class ConnectFormatSchemaTranslatorTest {

@Mock
-private Function<Schema, Schema> connectKsqlTranslator;
+private ConnectKsqlSchemaTranslator connectKsqlTranslator;
@Mock
private ParsedSchema parsedSchema;
@Mock
@@ -72,7 +72,7 @@ public class ConnectFormatSchemaTranslatorTest {

@Before
public void setUp() {
-when(connectKsqlTranslator.apply(any())).thenReturn(transformedSchema);
+when(connectKsqlTranslator.toKsqlSchema(any())).thenReturn(transformedSchema);
when(connectSchema.type()).thenReturn(Type.STRUCT);

when(format.getConnectSchemaTranslator(any())).thenReturn(innerTranslator);
@@ -93,7 +93,7 @@ public void shouldPassConnectSchemaReturnedBySubclassToTranslator() {
translator.toColumns(parsedSchema, SerdeFeatures.of(), false);

// Then:
-verify(connectKsqlTranslator).apply(connectSchema);
+verify(connectKsqlTranslator).toKsqlSchema(connectSchema);
}

@Test
@@ -157,7 +157,7 @@ public void shouldSupportBuildingColumnsFromPrimitiveValueSchema() {
translator.toColumns(parsedSchema, SerdeFeatures.of(SerdeFeature.UNWRAP_SINGLES), false);

// Then:
-verify(connectKsqlTranslator).apply(SchemaBuilder.struct()
+verify(connectKsqlTranslator).toKsqlSchema(SchemaBuilder.struct()
.field("ROWVAL", connectSchema)
.build());
}
@@ -171,7 +171,7 @@ public void shouldSupportBuildingColumnsFromPrimitiveKeySchema() {
translator.toColumns(parsedSchema, SerdeFeatures.of(SerdeFeature.UNWRAP_SINGLES), true);

// Then:
-verify(connectKsqlTranslator).apply(SchemaBuilder.struct()
+verify(connectKsqlTranslator).toKsqlSchema(SchemaBuilder.struct()
.field("ROWKEY", connectSchema)
.build());
}
[changed file 7 of 8]
@@ -15,9 +15,9 @@

package io.confluent.ksql.serde.connect;

-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_DATE_SCHEMA;
-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_TIMESTAMP_SCHEMA;
-import static io.confluent.ksql.serde.connect.ConnectSchemaUtil.OPTIONAL_TIME_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_DATE_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_TIMESTAMP_SCHEMA;
+import static io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator.OPTIONAL_TIME_SCHEMA;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
@@ -31,7 +31,7 @@
import org.junit.Test;


-public class ConnectSchemaUtilTest {
+public class ConnectKsqlSchemaTranslatorTest {

@Test
public void shouldTranslatePrimitives() {
@@ -45,7 +45,7 @@ public void shouldTranslatePrimitives() {
.field("bytesField", Schema.BYTES_SCHEMA)
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
assertThat(ksqlSchema.schema().type(), equalTo(Schema.Type.STRUCT));
assertThat(ksqlSchema.fields().size(), equalTo(connectSchema.fields().size()));
for (int i = 0; i < ksqlSchema.fields().size(); i++) {
@@ -69,7 +69,7 @@ public void shouldTranslateMaps() {
.field("mapField", SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA))
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
assertThat(ksqlSchema.field("MAPFIELD"), notNullValue());
final Schema mapSchema = ksqlSchema.field("MAPFIELD").schema();
assertThat(mapSchema.type(), equalTo(Schema.Type.MAP));
@@ -90,7 +90,7 @@ public void shouldTranslateStructInsideMap() {
.build()))
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
assertThat(ksqlSchema.field("MAPFIELD"), notNullValue());
final Schema mapSchema = ksqlSchema.field("MAPFIELD").schema();
assertThat(mapSchema.type(), equalTo(Schema.Type.MAP));
@@ -109,7 +109,7 @@ public void shouldTranslateArray() {
.field("arrayField", SchemaBuilder.array(Schema.INT32_SCHEMA))
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
assertThat(ksqlSchema.field("ARRAYFIELD"), notNullValue());
final Schema arraySchema = ksqlSchema.field("ARRAYFIELD").schema();
assertThat(arraySchema.type(), equalTo(Schema.Type.ARRAY));
@@ -129,7 +129,7 @@ public void shouldTranslateStructInsideArray() {
.build()))
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
assertThat(ksqlSchema.field("ARRAYFIELD"), notNullValue());
final Schema arraySchema = ksqlSchema.field("ARRAYFIELD").schema();
assertThat(arraySchema.type(), equalTo(Schema.Type.ARRAY));
@@ -150,7 +150,7 @@ public void shouldTranslateNested() {
.field("structField", connectInnerSchema)
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);
assertThat(ksqlSchema.field("STRUCTFIELD"), notNullValue());
final Schema innerSchema = ksqlSchema.field("STRUCTFIELD").schema();
assertThat(innerSchema.fields().size(), equalTo(connectInnerSchema.fields().size()));
@@ -172,7 +172,7 @@ public void shouldTranslateMapWithNonStringKey() {
.field("mapfield", SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA))
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);

assertThat(ksqlSchema.field("MAPFIELD"), notNullValue());
final Schema mapSchema = ksqlSchema.field("MAPFIELD").schema();
@@ -190,7 +190,7 @@ public void shouldTranslateTimeTypes() {
.field("timestampfield", Timestamp.SCHEMA)
.build();

-final Schema ksqlSchema = ConnectSchemaUtil.toKsqlSchema(connectSchema);
+final Schema ksqlSchema = new ConnectKsqlSchemaTranslator().toKsqlSchema(connectSchema);

assertThat(ksqlSchema.field("TIMEFIELD").schema(), equalTo(OPTIONAL_TIME_SCHEMA));
assertThat(ksqlSchema.field("DATEFIELD").schema(), equalTo(OPTIONAL_DATE_SCHEMA));
[changed file 8 of 8]
@@ -41,7 +41,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.serde.SerdeUtils;
-import io.confluent.ksql.serde.connect.ConnectSchemaUtil;
+import io.confluent.ksql.serde.connect.ConnectKsqlSchemaTranslator;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.io.IOException;
@@ -103,9 +103,9 @@ public class KsqlJsonDeserializerTest {
.map(Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_FLOAT64_SCHEMA)
.optional()
.build())
-.field(TIMEFIELD, ConnectSchemaUtil.OPTIONAL_TIME_SCHEMA)
-.field(DATEFIELD, ConnectSchemaUtil.OPTIONAL_DATE_SCHEMA)
-.field(TIMESTAMPFIELD, ConnectSchemaUtil.OPTIONAL_TIMESTAMP_SCHEMA)
+.field(TIMEFIELD, ConnectKsqlSchemaTranslator.OPTIONAL_TIME_SCHEMA)
+.field(DATEFIELD, ConnectKsqlSchemaTranslator.OPTIONAL_DATE_SCHEMA)
+.field(TIMESTAMPFIELD, ConnectKsqlSchemaTranslator.OPTIONAL_TIMESTAMP_SCHEMA)
.field(BYTESFIELD, Schema.OPTIONAL_BYTES_SCHEMA)
.build();

