diff --git a/pom.xml b/pom.xml
index 1b29948..d285dad 100644
--- a/pom.xml
+++ b/pom.xml
@@ -54,6 +54,21 @@
             <version>${flink.version}</version>
             <scope>provided</scope>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-connector-files</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-json</artifactId>
+            <version>${flink.version}</version>
+            <scope>test</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.flink</groupId>
             <artifactId>flink-clients</artifactId>
@@ -119,11 +134,23 @@
             <groupId>org.apache.rocketmq</groupId>
             <artifactId>rocketmq-broker</artifactId>
             <version>${rocketmq.version}</version>
+            <exclusions>
+                <exclusion>
+                    <artifactId>guava</artifactId>
+                    <groupId>com.google.guava</groupId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>org.apache.rocketmq</groupId>
             <artifactId>rocketmq-test</artifactId>
             <version>${rocketmq.version}</version>
+            <exclusions>
+                <exclusion>
+                    <artifactId>guava</artifactId>
+                    <groupId>com.google.guava</groupId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>
             <groupId>org.apache.rocketmq</groupId>
diff --git a/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java b/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java
index 50a0883..2931c6a 100644
--- a/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java
+++ b/src/main/java/org/apache/rocketmq/flink/common/RocketMQOptions.java
@@ -20,7 +20,11 @@
 
 import org.apache.flink.configuration.ConfigOption;
 import org.apache.flink.configuration.ConfigOptions;
+import org.apache.flink.configuration.description.Description;
 
+import java.util.List;
+
+import static org.apache.flink.table.factories.FactoryUtil.FORMAT_SUFFIX;
 import static org.apache.rocketmq.flink.legacy.RocketMQConfig.DEFAULT_START_MESSAGE_OFFSET;
 
 /** Includes config options of RocketMQ connector type. */
@@ -117,4 +121,83 @@ public class RocketMQOptions {
 
     public static final ConfigOption<Long> OPTIONAL_OFFSET_FROM_TIMESTAMP =
             ConfigOptions.key("offsetFromTimestamp").longType().noDefaultValue();
+
+    // --------------------------------------------------------------------------------------------
+    // Format options
+    // --------------------------------------------------------------------------------------------
+
+    public static final ConfigOption<String> VALUE_FORMAT =
+            ConfigOptions.key("value" + FORMAT_SUFFIX)
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription(
+                            "Defines the format identifier for encoding value data. "
+                                    + "The identifier is used to discover a suitable format factory.");
+
+    public static final ConfigOption<String> KEY_FORMAT =
+            ConfigOptions.key("key" + FORMAT_SUFFIX)
+                    .stringType()
+                    // .defaultValue("rocketmq-default")
+                    .noDefaultValue()
+                    .withDescription(
+                            "Defines the format identifier for encoding key data. "
+                                    + "The identifier is used to discover a suitable format factory.");
+
+    public static final ConfigOption<List<String>> KEY_FIELDS =
+            ConfigOptions.key("key.fields")
+                    .stringType()
+                    .asList()
+                    .defaultValues()
+                    .withDescription(
+                            "Defines an explicit list of physical columns from the table schema "
+                                    + "that configure the data type for the key format. By default, this list is "
+                                    + "empty and thus a key is undefined.");
+
+    public static final ConfigOption<ValueFieldsStrategy> VALUE_FIELDS_INCLUDE =
+            ConfigOptions.key("value.fields-include")
+                    .enumType(ValueFieldsStrategy.class)
+                    .defaultValue(ValueFieldsStrategy.ALL)
+                    .withDescription(
+                            String.format(
+                                    "Defines a strategy how to deal with key columns in the data type "
+                                            + "of the value format. By default, '%s' physical columns of the table schema "
+                                            + "will be included in the value format which means that the key columns "
+                                            + "appear in the data type for both the key and value format.",
+                                    ValueFieldsStrategy.ALL));
+
+    public static final ConfigOption<String> KEY_FIELDS_PREFIX =
+            ConfigOptions.key("key.fields-prefix")
+                    .stringType()
+                    .noDefaultValue()
+                    .withDescription(
+                            Description.builder()
+                                    .text(
+                                            "Defines a custom prefix for all fields of the key format to avoid "
+                                                    + "name clashes with fields of the value format. "
" + + "By default, the prefix is empty.") + .linebreak() + .text( + String.format( + "If a custom prefix is defined, both the table schema and '%s' will work with prefixed names.", + KEY_FIELDS.key())) + .linebreak() + .text( + "When constructing the data type of the key format, the prefix " + + "will be removed and the non-prefixed names will be used within the key format.") + .linebreak() + .text( + String.format( + "Please note that this option requires that '%s' must be '%s'.", + VALUE_FIELDS_INCLUDE.key(), + ValueFieldsStrategy.EXCEPT_KEY)) + .build()); + + // -------------------------------------------------------------------------------------------- + // Enums + // -------------------------------------------------------------------------------------------- + + public enum ValueFieldsStrategy { + ALL, + EXCEPT_KEY + } } diff --git a/src/main/java/org/apache/rocketmq/flink/legacy/common/serialization/SimpleKeyValueDeserializationSchema.java b/src/main/java/org/apache/rocketmq/flink/legacy/common/serialization/SimpleKeyValueDeserializationSchema.java index 1177f76..f1ce516 100644 --- a/src/main/java/org/apache/rocketmq/flink/legacy/common/serialization/SimpleKeyValueDeserializationSchema.java +++ b/src/main/java/org/apache/rocketmq/flink/legacy/common/serialization/SimpleKeyValueDeserializationSchema.java @@ -17,13 +17,14 @@ package org.apache.rocketmq.flink.legacy.common.serialization; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.MapTypeInfo; import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; -import org.apache.flink.api.java.typeutils.MapTypeInfo; -public class SimpleKeyValueDeserializationSchema implements KeyValueDeserializationSchema> { +public class SimpleKeyValueDeserializationSchema + implements KeyValueDeserializationSchema> { public static final String DEFAULT_KEY_FIELD = "key"; public static final String DEFAULT_VALUE_FIELD = "value"; diff --git a/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java b/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java index 8d98d2e..addcbf5 100644 --- a/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java +++ b/src/main/java/org/apache/rocketmq/flink/source/RocketMQSource.java @@ -29,7 +29,6 @@ import org.apache.rocketmq.flink.source.split.RocketMQPartitionSplit; import org.apache.rocketmq.flink.source.split.RocketMQPartitionSplitSerializer; -import org.apache.flink.api.common.serialization.DeserializationSchema; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.connector.source.Boundedness; import org.apache.flink.api.connector.source.Source; @@ -123,12 +122,13 @@ public Boundedness getBoundedness() { } @Override - public SourceReader createReader( - SourceReaderContext readerContext) { + public SourceReader createReader(SourceReaderContext readerContext) + throws Exception { FutureCompletingBlockingQueue>> elementsQueue = new FutureCompletingBlockingQueue<>(); deserializationSchema.open( - new DeserializationSchema.InitializationContext() { + new org.apache.flink.api.common.serialization.DeserializationSchema + .InitializationContext() { @Override public MetricGroup getMetricGroup() { return readerContext.metricGroup(); diff --git a/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQDeserializationSchema.java b/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQDeserializationSchema.java index e50b702..8353310 
--- a/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQDeserializationSchema.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQDeserializationSchema.java
@@ -42,7 +42,7 @@ public interface RocketMQDeserializationSchema<T>
      */
     @Override
     @PublicEvolving
-    default void open(InitializationContext context) {}
+    default void open(InitializationContext context) throws Exception {}
 
     /**
      * Deserializes the byte message.
diff --git a/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchema.java b/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchema.java
index 5bd990e..c87edbc 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchema.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchema.java
@@ -28,12 +28,13 @@
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.util.Collector;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
 /**
- * A row data wrapper class that wraps a {@link RocketMQDeserializationSchema} to deserialize {@link
+ * A row data wrapper class that wraps a {@link DeserializationSchema} to deserialize {@link
  * MessageExt}.
  */
 public class RocketMQRowDeserializationSchema implements RocketMQDeserializationSchema<RowData> {
@@ -46,6 +47,10 @@ public class RocketMQRowDeserializationSchema implements RocketMQDeserialization
 
     public RocketMQRowDeserializationSchema(
             TableSchema tableSchema,
+            org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                    keyDeserialization,
+            org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                    valueDeserialization,
             Map<String, String> properties,
             boolean hasMetadata,
             MetadataConverter[] metadataConverters) {
@@ -55,17 +60,20 @@ public RocketMQRowDeserializationSchema(
                         .setTableSchema(tableSchema)
                         .setHasMetadata(hasMetadata)
                         .setMetadataConverters(metadataConverters)
+                        .setKeyDeserialization(keyDeserialization)
+                        .setValueDeserialization(valueDeserialization)
                         .build();
     }
 
     @Override
-    public void open(InitializationContext context) {
+    public void open(InitializationContext context) throws Exception {
         deserializationSchema.open(context);
         bytesMessages = new ArrayList<>();
     }
 
     @Override
-    public void deserialize(List<MessageExt> input, Collector<RowData> collector) {
+    public void deserialize(List<MessageExt> input, Collector<RowData> collector)
+            throws IOException {
         extractMessages(input);
         deserializationSchema.deserialize(bytesMessages, collector);
     }
diff --git a/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RowDeserializationSchema.java b/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RowDeserializationSchema.java
index 8beaaa2..bb47f15 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RowDeserializationSchema.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/reader/deserializer/RowDeserializationSchema.java
@@ -42,8 +42,12 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nullable;
+
+import java.io.IOException;
 import java.io.Serializable;
 import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -60,6 +64,16 @@ public class RowDeserializationSchema
 
     private static final Logger LOGGER = LoggerFactory.getLogger(RowDeserializationSchema.class);
 
     private transient TableSchema tableSchema;
+
+    private final @Nullable org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+            keyDeserialization;
+
+    private final @Nullable org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+            valueDeserialization;
+
+    /** Indices that determine the key fields and the target position in the produced row. */
+    private final int[] keyProjection;
+
     private final DirtyDataStrategy formatErrorStrategy;
     private final DirtyDataStrategy fieldMissingStrategy;
     private final DirtyDataStrategy fieldIncrementStrategy;
@@ -67,6 +81,8 @@ public class RowDeserializationSchema
     private final String fieldDelimiter;
     private final String lineDelimiter;
     private final boolean columnErrorDebug;
+
+    private final BufferingCollector keyCollector;
     private final MetadataCollector metadataCollector;
     private final int totalColumnSize;
     private final int dataColumnSize;
@@ -83,6 +99,12 @@ public class RowDeserializationSchema
 
     public RowDeserializationSchema(
             TableSchema tableSchema,
+            @Nullable
+                    org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                            keyDeserialization,
+            org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                    valueDeserialization,
+            int[] keyProjection,
             DirtyDataStrategy formatErrorStrategy,
             DirtyDataStrategy fieldMissingStrategy,
             DirtyDataStrategy fieldIncrementStrategy,
@@ -95,6 +117,9 @@ public RowDeserializationSchema(
             List<String> headerFields,
             Map<String, String> properties) {
         this.tableSchema = tableSchema;
+        this.keyDeserialization = keyDeserialization;
+        this.valueDeserialization = valueDeserialization;
+        this.keyProjection = keyProjection;
         this.formatErrorStrategy = formatErrorStrategy;
         this.fieldMissingStrategy = fieldMissingStrategy;
         this.fieldIncrementStrategy = fieldIncrementStrategy;
@@ -102,6 +127,7 @@ public RowDeserializationSchema(
         this.encoding = encoding;
         this.fieldDelimiter = StringEscapeUtils.unescapeJava(fieldDelimiter);
         this.lineDelimiter = StringEscapeUtils.unescapeJava(lineDelimiter);
+        this.keyCollector = new BufferingCollector();
         this.metadataCollector = new MetadataCollector(hasMetadata, metadataConverters);
         this.headerFields = headerFields == null ? null : new HashSet<>(headerFields);
         this.properties = properties;
@@ -126,19 +152,42 @@ public RowDeserializationSchema(
     }
 
     @Override
-    public void open(InitializationContext context) {
+    public void open(InitializationContext context) throws Exception {
         DescriptorProperties descriptorProperties = new DescriptorProperties();
         descriptorProperties.putProperties(properties);
         this.tableSchema = SchemaValidator.deriveTableSinkSchema(descriptorProperties);
         this.fieldDataTypes = tableSchema.getFieldDataTypes();
         this.lastLogExceptionTime = System.currentTimeMillis();
         this.lastLogHandleFieldTime = System.currentTimeMillis();
+
+        if (keyDeserialization != null) {
+            keyDeserialization.open(context);
+        }
+        if (valueDeserialization != null) {
+            valueDeserialization.open(context);
+        }
     }
 
     @Override
-    public void deserialize(List<BytesMessage> messages, Collector<RowData> collector) {
+    public void deserialize(List<BytesMessage> messages, Collector<RowData> collector)
+            throws IOException {
         metadataCollector.collector = collector;
-        deserialize(messages, metadataCollector);
+
+        if (keyDeserialization == null && valueDeserialization == null) {
+            // Use default deserializer
+            deserialize(messages, metadataCollector);
+            return;
+        }
+
+        if (keyDeserialization == null && !metadataCollector.hasMetadata) {
+            for (BytesMessage message : messages) {
+                valueDeserialization.deserialize(message.getData(), collector);
+            }
+        } else {
+            // TODO Implement key deserialization
+            LOGGER.error("keyDeserialization not support yet");
+            throw new RuntimeException("keyDeserialization not support yet");
+        }
     }
 
     private void deserialize(List<BytesMessage> messages, MetadataCollector collector) {
@@ -405,11 +454,31 @@ public TypeInformation<RowData> getProducedType() {
     }
 
     /** Source metadata converter interface. */
     public interface MetadataConverter extends Serializable {
+
         Object read(BytesMessage message);
     }
 
     // --------------------------------------------------------------------------------------------
 
+    private static final class BufferingCollector implements Collector<RowData>, Serializable {
+
+        private static final long serialVersionUID = 1L;
+
+        private final List<RowData> buffer = new ArrayList<>();
+
+        @Override
+        public void collect(RowData record) {
+            buffer.add(record);
+        }
+
+        @Override
+        public void close() {
+            // nothing to do
+        }
+    }
+
+    // --------------------------------------------------------------------------------------------
+
     /** Metadata of RowData collector. */
     public static final class MetadataCollector implements Collector<RowData>, Serializable {
@@ -457,6 +526,12 @@ public void close() {
     public static class Builder {
 
         private TableSchema schema;
+        private org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                keyDeserialization;
+
+        private org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                valueDeserialization;
+        private int[] keyProjection;
         private DirtyDataStrategy formatErrorStrategy = DirtyDataStrategy.SKIP;
         private DirtyDataStrategy fieldMissingStrategy = DirtyDataStrategy.SKIP;
         private DirtyDataStrategy fieldIncrementStrategy = DirtyDataStrategy.CUT;
@@ -476,6 +551,20 @@ public Builder setTableSchema(TableSchema tableSchema) {
             return this;
         }
 
+        public Builder setKeyDeserialization(
+                org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                        keyDeserialization) {
+            this.keyDeserialization = keyDeserialization;
+            return this;
+        }
+
+        public Builder setValueDeserialization(
+                org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                        valueDeserialization) {
+            this.valueDeserialization = valueDeserialization;
+            return this;
+        }
+
         public Builder setFormatErrorStrategy(DirtyDataStrategy formatErrorStrategy) {
             this.formatErrorStrategy = formatErrorStrategy;
             return this;
@@ -491,6 +580,11 @@ public Builder setFieldIncrementStrategy(DirtyDataStrategy fieldIncrementStrateg
             return this;
         }
 
+        public Builder setKeyProjection(int[] keyProjection) {
+            this.keyProjection = keyProjection;
+            return this;
+        }
+
         public Builder setEncoding(String encoding) {
             this.encoding = encoding;
             return this;
@@ -577,6 +671,9 @@ public Builder setProperties(Map<String, String> properties) {
         public RowDeserializationSchema build() {
             return new RowDeserializationSchema(
                     schema,
+                    keyDeserialization,
+                    valueDeserialization,
+                    keyProjection,
                     formatErrorStrategy,
                     fieldMissingStrategy,
                     fieldIncrementStrategy,
@@ -593,6 +690,7 @@ public RowDeserializationSchema build() {
 
     /** Options for {@link RowDeserializationSchema}. */
     public static class CollectorOption {
+
         public static final ConfigOption<String> ENCODING =
                 ConfigOptions.key("encoding".toLowerCase()).defaultValue("UTF-8");
         public static final ConfigOption<String> FIELD_DELIMITER =
diff --git a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQConnectorOptionsUtil.java b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQConnectorOptionsUtil.java
new file mode 100644
index 0000000..c2cefec
--- /dev/null
+++ b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQConnectorOptionsUtil.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.rocketmq.flink.source.table;
+
+import org.apache.rocketmq.flink.common.RocketMQOptions;
+import org.apache.rocketmq.flink.common.RocketMQOptions.ValueFieldsStrategy;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.logical.utils.LogicalTypeChecks;
+import org.apache.flink.util.Preconditions;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.IntStream;
+
+import static org.apache.rocketmq.flink.common.RocketMQOptions.KEY_FIELDS;
+import static org.apache.rocketmq.flink.common.RocketMQOptions.KEY_FIELDS_PREFIX;
+import static org.apache.rocketmq.flink.common.RocketMQOptions.KEY_FORMAT;
+import static org.apache.rocketmq.flink.common.RocketMQOptions.VALUE_FIELDS_INCLUDE;
+
+/** Utilities for {@link RocketMQOptions}. */
+@Internal
+class RocketMQConnectorOptionsUtil {
+
+    /**
+     * Creates an array of indices that determine which physical fields of the table schema to
+     * include in the key format and the order that those fields have in the key format.
+     *
+     * <p>See {@link RocketMQOptions#KEY_FORMAT}, {@link RocketMQOptions#KEY_FIELDS}, and {@link
+     * RocketMQOptions#KEY_FIELDS_PREFIX} for more information.
+     */
+    public static int[] createKeyFormatProjection(
+            ReadableConfig options, DataType physicalDataType) {
+        final LogicalType physicalType = physicalDataType.getLogicalType();
+        Preconditions.checkArgument(
+                physicalType.is(LogicalTypeRoot.ROW), "Row data type expected.");
+        final Optional<String> optionalKeyFormat = options.getOptional(KEY_FORMAT);
+        final Optional<List<String>> optionalKeyFields = options.getOptional(KEY_FIELDS);
+
+        if (!optionalKeyFormat.isPresent() && optionalKeyFields.isPresent()) {
+            throw new ValidationException(
+                    String.format(
+                            "The option '%s' can only be declared if a key format is defined using '%s'.",
+                            KEY_FIELDS.key(), KEY_FORMAT.key()));
+        } else if (optionalKeyFormat.isPresent()
+                && (!optionalKeyFields.isPresent() || optionalKeyFields.get().size() == 0)) {
+            throw new ValidationException(
+                    String.format(
+                            "A key format '%s' requires the declaration of one or more of key fields using '%s'.",
+                            KEY_FORMAT.key(), KEY_FIELDS.key()));
+        }
+
+        if (!optionalKeyFormat.isPresent()) {
+            return new int[0];
+        }
+
+        final String keyPrefix = options.getOptional(KEY_FIELDS_PREFIX).orElse("");
+
+        final List<String> keyFields = optionalKeyFields.get();
+        final List<String> physicalFields = LogicalTypeChecks.getFieldNames(physicalType);
+        return keyFields.stream()
+                .mapToInt(
+                        keyField -> {
+                            final int pos = physicalFields.indexOf(keyField);
+                            // check that field name exists
+                            if (pos < 0) {
+                                throw new ValidationException(
+                                        String.format(
+                                                "Could not find the field '%s' in the table schema for usage in the key format. "
+                                                        + "A key field must be a regular, physical column. "
+                                                        + "The following columns can be selected in the '%s' option:\n"
+                                                        + "%s",
+                                                keyField, KEY_FIELDS.key(), physicalFields));
+                            }
+                            // check that field name is prefixed correctly
+                            if (!keyField.startsWith(keyPrefix)) {
+                                throw new ValidationException(
+                                        String.format(
+                                                "All fields in '%s' must be prefixed with '%s' when option '%s' "
+                                                        + "is set but field '%s' is not prefixed.",
+                                                KEY_FIELDS.key(),
+                                                keyPrefix,
+                                                KEY_FIELDS_PREFIX.key(),
+                                                keyField));
+                            }
+                            return pos;
+                        })
+                .toArray();
+    }
+
+    /**
+     * Creates an array of indices that determine which physical fields of the table schema to
+     * include in the value format.
+     *
+     * <p>See {@link RocketMQOptions#VALUE_FORMAT}, {@link RocketMQOptions#VALUE_FIELDS_INCLUDE},
+     * and {@link RocketMQOptions#KEY_FIELDS_PREFIX} for more information.
+     */
+    public static int[] createValueFormatProjection(
+            ReadableConfig options, DataType physicalDataType) {
+        final LogicalType physicalType = physicalDataType.getLogicalType();
+        Preconditions.checkArgument(
+                physicalType.is(LogicalTypeRoot.ROW), "Row data type expected.");
+        final int physicalFieldCount = LogicalTypeChecks.getFieldCount(physicalType);
+        final IntStream physicalFields = IntStream.range(0, physicalFieldCount);
+
+        final String keyPrefix = options.getOptional(KEY_FIELDS_PREFIX).orElse("");
+
+        final ValueFieldsStrategy strategy = options.get(VALUE_FIELDS_INCLUDE);
+        if (strategy == ValueFieldsStrategy.ALL) {
+            if (keyPrefix.length() > 0) {
+                throw new ValidationException(
+                        String.format(
+                                "A key prefix is not allowed when option '%s' is set to '%s'. "
+                                        + "Set it to '%s' instead to avoid field overlaps.",
+                                VALUE_FIELDS_INCLUDE.key(),
+                                ValueFieldsStrategy.ALL,
+                                ValueFieldsStrategy.EXCEPT_KEY));
+            }
+            return physicalFields.toArray();
+        } else if (strategy == ValueFieldsStrategy.EXCEPT_KEY) {
+            final int[] keyProjection = createKeyFormatProjection(options, physicalDataType);
+            return physicalFields
+                    .filter(pos -> IntStream.of(keyProjection).noneMatch(k -> k == pos))
+                    .toArray();
+        }
+        throw new TableException("Unknown value fields strategy:" + strategy);
+    }
+
+    private RocketMQConnectorOptionsUtil() {}
+}
diff --git a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java
index 8b4fd52..c0f8e3b 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactory.java
@@ -20,14 +20,23 @@
 
 import org.apache.rocketmq.flink.common.RocketMQOptions;
 
+import org.apache.flink.api.common.serialization.DeserializationSchema;
 import org.apache.flink.configuration.ConfigOption;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.table.api.TableSchema;
+import org.apache.flink.table.api.ValidationException;
+import org.apache.flink.table.connector.format.DecodingFormat;
 import org.apache.flink.table.connector.source.DynamicTableSource;
+import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.descriptors.DescriptorProperties;
+import org.apache.flink.table.factories.DeserializationFormatFactory;
 import org.apache.flink.table.factories.DynamicTableSourceFactory;
 import org.apache.flink.table.factories.FactoryUtil;
+import org.apache.flink.table.factories.FactoryUtil.TableFactoryHelper;
+import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.utils.TableSchemaUtils;
+import org.apache.flink.types.RowKind;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StringUtils;
@@ -36,11 +45,14 @@
 import java.text.ParseException;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.TimeZone;
 
+import static org.apache.flink.table.factories.FactoryUtil.FORMAT;
 import static org.apache.flink.table.factories.FactoryUtil.createTableFactoryHelper;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.CONSUMER_GROUP;
+import static org.apache.rocketmq.flink.common.RocketMQOptions.KEY_FORMAT;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.NAME_SERVER_ADDRESS;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_ACCESS_KEY;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_COLUMN_ERROR_DEBUG;
@@ -61,7 +73,10 @@
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_TIME_ZONE;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.OPTIONAL_USE_NEW_API;
 import static org.apache.rocketmq.flink.common.RocketMQOptions.TOPIC;
+import static org.apache.rocketmq.flink.common.RocketMQOptions.VALUE_FORMAT;
 import static org.apache.rocketmq.flink.legacy.RocketMQConfig.CONSUMER_OFFSET_LATEST;
+import static org.apache.rocketmq.flink.source.table.RocketMQConnectorOptionsUtil.createKeyFormatProjection;
+import static org.apache.rocketmq.flink.source.table.RocketMQConnectorOptionsUtil.createValueFormatProjection;
 
 /**
  * Defines the {@link DynamicTableSourceFactory} implementation to create {@link
@@ -106,12 +121,19 @@ public Set<ConfigOption<?>> optionalOptions() {
         optionalOptions.add(OPTIONAL_SECRET_KEY);
         optionalOptions.add(OPTIONAL_SCAN_STARTUP_MODE);
         optionalOptions.add(OPTIONAL_CONSUMER_POLL_MS);
+        optionalOptions.add(FORMAT);
+        optionalOptions.add(KEY_FORMAT);
+        optionalOptions.add(VALUE_FORMAT);
         return optionalOptions;
     }
 
     @Override
     public DynamicTableSource createDynamicTableSource(Context context) {
         FactoryUtil.TableFactoryHelper helper = createTableFactoryHelper(this, context);
+        final Optional<DecodingFormat<DeserializationSchema<RowData>>> keyDecodingFormat =
+                getKeyDecodingFormat(helper);
+        final Optional<DecodingFormat<DeserializationSchema<RowData>>> valueDecodingFormat =
+                getValueDecodingFormat(helper);
         helper.validate();
         Map<String, String> rawProperties = context.getCatalogTable().getOptions();
         Configuration configuration = Configuration.fromMap(rawProperties);
@@ -183,10 +205,24 @@ public DynamicTableSource createDynamicTableSource(Context context) {
         long consumerOffsetTimestamp =
                 configuration.getLong(
                         RocketMQOptions.OPTIONAL_OFFSET_FROM_TIMESTAMP, System.currentTimeMillis());
+
+        final ReadableConfig tableOptions = helper.getOptions();
+
+        final DataType physicalDataType =
+                context.getCatalogTable().getResolvedSchema().toPhysicalRowDataType();
+        final int[] keyProjection = createKeyFormatProjection(tableOptions, physicalDataType);
+
+        final int[] valueProjection = createValueFormatProjection(tableOptions, physicalDataType);
+
         return new RocketMQScanTableSource(
                 configuration.getLong(OPTIONAL_CONSUMER_POLL_MS),
                 descriptorProperties,
                 physicalSchema,
+                keyDecodingFormat.orElse(null),
+                valueDecodingFormat.orElse(null),
+                keyProjection,
+                valueProjection,
+                physicalDataType,
                 topic,
                 consumerGroup,
                 nameServerAddress,
@@ -208,4 +244,37 @@ private Long parseDateString(String dateString, String timeZone) throws ParseExc
                 FastDateFormat.getInstance(DATE_FORMAT, TimeZone.getTimeZone(timeZone));
         return simpleDateFormat.parse(dateString).getTime();
     }
+
+    private static Optional<DecodingFormat<DeserializationSchema<RowData>>> getValueDecodingFormat(
+            TableFactoryHelper helper) {
+        Optional<DecodingFormat<DeserializationSchema<RowData>>>
+                deserializationSchemaDecodingFormat =
+                        helper.discoverOptionalDecodingFormat(
+                                DeserializationFormatFactory.class, FORMAT);
+        if (deserializationSchemaDecodingFormat.isPresent()) {
+            return deserializationSchemaDecodingFormat;
+        }
+
+        return helper.discoverOptionalDecodingFormat(
+                DeserializationFormatFactory.class, VALUE_FORMAT);
+    }
+
+    private static Optional<DecodingFormat<DeserializationSchema<RowData>>> getKeyDecodingFormat(
+            TableFactoryHelper helper) {
+        final Optional<DecodingFormat<DeserializationSchema<RowData>>> keyDecodingFormat =
+                helper.discoverOptionalDecodingFormat(
+                        DeserializationFormatFactory.class, KEY_FORMAT);
+        keyDecodingFormat.ifPresent(
+                format -> {
+                    if (!format.getChangelogMode().containsOnly(RowKind.INSERT)) {
+                        throw new ValidationException(
+                                String.format(
+                                        "A key format should only deal with INSERT-only records. "
+                                                + "But %s has a changelog mode of %s.",
+                                        helper.getOptions().get(KEY_FORMAT),
+                                        format.getChangelogMode()));
+                    }
+                });
+        return keyDecodingFormat;
+    }
 }
diff --git a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java
index 3eb68df..d549721 100644
--- a/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java
+++ b/src/main/java/org/apache/rocketmq/flink/source/table/RocketMQScanTableSource.java
@@ -27,9 +27,13 @@
 import org.apache.rocketmq.flink.source.reader.deserializer.RocketMQRowDeserializationSchema;
 import org.apache.rocketmq.flink.source.reader.deserializer.RowDeserializationSchema.MetadataConverter;
 
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.TableSchema;
 import org.apache.flink.table.connector.ChangelogMode;
+import org.apache.flink.table.connector.Projection;
+import org.apache.flink.table.connector.format.DecodingFormat;
 import org.apache.flink.table.connector.source.DynamicTableSource;
 import org.apache.flink.table.connector.source.ScanTableSource;
 import org.apache.flink.table.connector.source.SourceFunctionProvider;
@@ -39,6 +43,9 @@
 import org.apache.flink.table.data.StringData;
 import org.apache.flink.table.descriptors.DescriptorProperties;
 import org.apache.flink.table.types.DataType;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
 
 import java.util.Collections;
 import java.util.LinkedHashMap;
@@ -56,6 +63,25 @@ public class RocketMQScanTableSource implements ScanTableSource, SupportsReading
 
     private final DescriptorProperties properties;
     private final TableSchema schema;
+
+    /** Data type that describes the final output of the source. */
+    protected DataType producedDataType;
+
+    /** Optional format for decoding keys from RocketMQ. */
+    private final @Nullable DecodingFormat<
+                    org.apache.flink.api.common.serialization.DeserializationSchema<RowData>>
+            keyDecodingFormat;
+
+    /** Format for decoding values from RocketMQ. */
+    private final @Nullable DecodingFormat<
+                    org.apache.flink.api.common.serialization.DeserializationSchema<RowData>>
+            valueDecodingFormat;
+
+    /** Indices that determine the key fields and the target position in the produced row. */
+    protected final int[] keyProjection;
+
+    /** Indices that determine the value fields and the target position in the produced row. */
+    protected final int[] valueProjection;
 
     private final String consumerOffsetMode;
     private final long consumerOffsetTimestamp;
@@ -81,6 +107,16 @@ public RocketMQScanTableSource(
             long pollTime,
             DescriptorProperties properties,
             TableSchema schema,
+            @Nullable
+                    DecodingFormat<
+                                    org.apache.flink.api.common.serialization.DeserializationSchema<
+                                            RowData>>
+                            keyDecodingFormat,
+            DecodingFormat<org.apache.flink.api.common.serialization.DeserializationSchema<RowData>>
+                    valueDecodingFormat,
+            int[] keyProjection,
+            int[] valueProjection,
+            DataType physicalDataType,
             String topic,
             String consumerGroup,
             String nameServerAddress,
@@ -98,6 +134,13 @@ public RocketMQScanTableSource(
         this.pollTime = pollTime;
         this.properties = properties;
         this.schema = schema;
+        this.keyDecodingFormat = keyDecodingFormat;
+        this.valueDecodingFormat = valueDecodingFormat;
+        this.keyProjection =
+                Preconditions.checkNotNull(keyProjection, "Key projection must not be null.");
+        this.valueProjection =
+                Preconditions.checkNotNull(valueProjection, "Value projection must not be null.");
+        this.producedDataType = physicalDataType;
         this.topic = topic;
         this.consumerGroup = consumerGroup;
         this.nameServerAddress = nameServerAddress;
@@ -123,7 +166,19 @@ public ChangelogMode getChangelogMode() {
     @Override
     public ScanRuntimeProvider getScanRuntimeProvider(ScanContext scanContext) {
         if (useNewApi) {
-            return SourceProvider.of(
+            final org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                    keyDeserialization =
+                            createDeserialization(scanContext, keyDecodingFormat, keyProjection);
+
+            final org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+                    valueDeserialization =
+                            createDeserialization(
+                                    scanContext, valueDecodingFormat, valueProjection);
+
+            final TypeInformation<RowData> producedTypeInfo =
+                    scanContext.createTypeInformation(producedDataType);
+
+            RocketMQSource<RowData> rocketMQSource =
                     new RocketMQSource<>(
                             pollTime,
                             topic,
@@ -138,9 +193,11 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext scanContext) {
                             startMessageOffset < 0 ? 0 : startMessageOffset,
                             partitionDiscoveryIntervalMs,
                             isBounded() ? BOUNDED : CONTINUOUS_UNBOUNDED,
-                            createRocketMQDeserializationSchema(),
+                            createRocketMQDeserializationSchema(
+                                    keyDeserialization, valueDeserialization, producedTypeInfo),
                             consumerOffsetMode,
-                            consumerOffsetTimestamp));
+                            consumerOffsetTimestamp);
+            return SourceProvider.of(rocketMQSource);
         } else {
             return SourceFunctionProvider.of(
                     new RocketMQSourceFunction<>(
@@ -169,6 +226,11 @@ public DynamicTableSource copy() {
                         pollTime,
                         properties,
                         schema,
+                        keyDecodingFormat,
+                        valueDecodingFormat,
+                        keyProjection,
+                        valueProjection,
+                        producedDataType,
                         topic,
                         consumerGroup,
                         nameServerAddress,
@@ -192,7 +254,11 @@ public String asSummaryString() {
         return RocketMQScanTableSource.class.getName();
     }
 
-    private RocketMQDeserializationSchema<RowData> createRocketMQDeserializationSchema() {
+    private RocketMQDeserializationSchema<RowData> createRocketMQDeserializationSchema(
+            DeserializationSchema<RowData> keyDeserialization,
+            DeserializationSchema<RowData> valueDeserialization,
+            TypeInformation<RowData> producedTypeInfo) {
+
         final MetadataConverter[] metadataConverters =
                 metadataKeys.stream()
                         .map(
@@ -203,8 +269,14 @@ private RocketMQDeserializationSchema<RowData> createRocketMQDeserializationSche
                                                 .orElseThrow(IllegalStateException::new))
                         .map(m -> m.converter)
                         .toArray(MetadataConverter[]::new);
+
         return new RocketMQRowDeserializationSchema(
-                schema, properties.asMap(), metadataKeys.size() > 0, metadataConverters);
+                schema,
+                keyDeserialization,
+                valueDeserialization,
+                properties.asMap(),
+                metadataKeys.size() > 0,
+                metadataConverters);
     }
 
     private boolean isBounded() {
@@ -262,4 +334,22 @@ public Object read(BytesMessage message) {
             this.converter = converter;
         }
     }
+
+    private @Nullable org.apache.flink.api.common.serialization.DeserializationSchema<RowData>
+            createDeserialization(
+                    DynamicTableSource.Context context,
+                    @Nullable
+                            DecodingFormat<
+                                            org.apache.flink.api.common.serialization
+                                                    .DeserializationSchema<RowData>>
+                                    format,
+                    int[] projection) {
+        if (format == null) {
+            return null;
+        }
+
+        DataType physicalFormatDataType = Projection.of(projection).project(this.producedDataType);
+        return format.createRuntimeDecoder(context, physicalFormatDataType);
+    }
 }
diff --git a/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java b/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java
index 9e78190..30e28df 100644
--- a/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java
+++ b/src/test/java/org/apache/rocketmq/flink/legacy/RocketMQSourceTest.java
@@ -18,7 +18,6 @@
 
 package org.apache.rocketmq.flink.legacy;
 
-import java.util.Map;
 import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer;
 import org.apache.rocketmq.client.consumer.PullResult;
 import org.apache.rocketmq.client.consumer.PullStatus;
@@ -36,6 +35,7 @@
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
diff --git a/src/test/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchemaTest.java b/src/test/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchemaTest.java
index a904b04..37cc038 100644
--- a/src/test/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchemaTest.java
+++ b/src/test/java/org/apache/rocketmq/flink/source/reader/deserializer/RocketMQRowDeserializationSchemaTest.java
@@ -21,11 +21,13 @@
 
 import org.apache.rocketmq.common.message.MessageExt;
 
 import org.apache.flink.api.common.serialization.DeserializationSchema.InitializationContext;
+import org.apache.flink.formats.json.JsonRowDataDeserializationSchema;
 import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.TableSchema;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.util.Collector;
 
+import org.junit.Before;
 import org.junit.Test;
 import org.powermock.reflect.Whitebox;
@@ -33,6 +35,7 @@
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Mockito.doNothing;
@@ -41,9 +44,11 @@
 /** Test for {@link RocketMQRowDeserializationSchema}. */
 public class RocketMQRowDeserializationSchemaTest {
 
-    @Test
-    public void testDeserialize() {
-        TableSchema tableSchema =
+    private TableSchema tableSchema;
+
+    @Before
+    public void setup() {
+        tableSchema =
                 new TableSchema.Builder()
                         .field("int", DataTypes.INT())
                         .field("varchar", DataTypes.VARCHAR(100))
@@ -59,8 +64,13 @@ public void testDeserialize() {
                         .field("time", DataTypes.TIME())
                         .field("timestamp", DataTypes.TIMESTAMP())
                         .build();
+    }
+
+    @Test
+    public void testDeserialize() throws Exception {
         RocketMQRowDeserializationSchema recordDeserializer =
-                new RocketMQRowDeserializationSchema(tableSchema, new HashMap<>(), false, null);
+                new RocketMQRowDeserializationSchema(
+                        tableSchema, null, null, new HashMap<>(), false, null);
         RowDeserializationSchema sourceDeserializer = mock(RowDeserializationSchema.class);
         InitializationContext initializationContext = mock(InitializationContext.class);
         doNothing().when(sourceDeserializer).open(initializationContext);
@@ -138,4 +148,95 @@ public void testDeserialize() {
                 String.valueOf(thirdMsg.getQueueOffset()),
                 recordDeserializer.getBytesMessages().get(2).getProperty("__queue_offset__"));
     }
+
+    @Test
+    public void testJsonDeserialize() throws Exception {
+        Map<String, String> props = new HashMap<>();
+        props.put("consumergroup", "please_rename_unique_group_name");
+        props.put("nameserveraddress", "10.211.55.5:9876");
+        props.put("schema.0.data-type", "VARCHAR(2147483647)");
+        props.put("connector", "rocketmq");
+        props.put("schema.0.name", "id");
+
+        JsonRowDataDeserializationSchema valueDeserializer =
+                mock(JsonRowDataDeserializationSchema.class);
+        RocketMQRowDeserializationSchema recordDeserializer =
+                new RocketMQRowDeserializationSchema(
+                        tableSchema, null, valueDeserializer, props, false, null);
+
+        InitializationContext initializationContext = mock(InitializationContext.class);
+        doNothing().when(valueDeserializer).open(initializationContext);
+        recordDeserializer.open(initializationContext);
+        MessageExt firstMsg =
+                new MessageExt(
+                        1,
+                        System.currentTimeMillis(),
+                        InetSocketAddress.createUnresolved("localhost", 8080),
+                        System.currentTimeMillis(),
+                        InetSocketAddress.createUnresolved("localhost", 8088),
+                        "184019387");
+        firstMsg.setBody("test_deserializer_raw_messages_1".getBytes());
+        MessageExt secondMsg =
+                new MessageExt(
+                        1,
+                        System.currentTimeMillis(),
+                        InetSocketAddress.createUnresolved("localhost", 8081),
+                        System.currentTimeMillis(),
+                        InetSocketAddress.createUnresolved("localhost", 8087),
+                        "284019387");
+        secondMsg.setBody("test_deserializer_raw_messages_2".getBytes());
+        MessageExt thirdMsg =
+                new MessageExt(
+                        1,
+                        System.currentTimeMillis(),
+                        InetSocketAddress.createUnresolved("localhost", 8082),
+                        System.currentTimeMillis(),
+                        InetSocketAddress.createUnresolved("localhost", 8086),
+                        "384019387");
+        thirdMsg.setBody("test_deserializer_raw_messages_3".getBytes());
+        List<MessageExt> messages = Arrays.asList(firstMsg, secondMsg, thirdMsg);
+        Collector<RowData> collector = mock(Collector.class);
+        recordDeserializer.deserialize(messages, collector);
+
+        assertEquals(3, recordDeserializer.getBytesMessages().size());
+        assertEquals(firstMsg.getBody(), recordDeserializer.getBytesMessages().get(0).getData());
+        assertEquals(
+                String.valueOf(firstMsg.getStoreTimestamp()),
+                recordDeserializer.getBytesMessages().get(0).getProperty("__store_timestamp__"));
+        assertEquals(
+                String.valueOf(firstMsg.getBornTimestamp()),
+                recordDeserializer.getBytesMessages().get(0).getProperty("__born_timestamp__"));
+        assertEquals(
+                String.valueOf(firstMsg.getQueueId()),
+                recordDeserializer.getBytesMessages().get(0).getProperty("__queue_id__"));
+        assertEquals(
+                String.valueOf(firstMsg.getQueueOffset()),
+                recordDeserializer.getBytesMessages().get(0).getProperty("__queue_offset__"));
+        assertEquals(secondMsg.getBody(), recordDeserializer.getBytesMessages().get(1).getData());
+        assertEquals(
+                String.valueOf(secondMsg.getStoreTimestamp()),
+                recordDeserializer.getBytesMessages().get(1).getProperty("__store_timestamp__"));
+        assertEquals(
+                String.valueOf(secondMsg.getBornTimestamp()),
+                recordDeserializer.getBytesMessages().get(1).getProperty("__born_timestamp__"));
+        assertEquals(
+                String.valueOf(secondMsg.getQueueId()),
+                recordDeserializer.getBytesMessages().get(1).getProperty("__queue_id__"));
+        assertEquals(
+                String.valueOf(secondMsg.getQueueOffset()),
+                recordDeserializer.getBytesMessages().get(1).getProperty("__queue_offset__"));
+        assertEquals(thirdMsg.getBody(), recordDeserializer.getBytesMessages().get(2).getData());
+        assertEquals(
+                String.valueOf(thirdMsg.getStoreTimestamp()),
+                recordDeserializer.getBytesMessages().get(2).getProperty("__store_timestamp__"));
+        assertEquals(
+                String.valueOf(thirdMsg.getBornTimestamp()),
+                recordDeserializer.getBytesMessages().get(2).getProperty("__born_timestamp__"));
+        assertEquals(
+                String.valueOf(thirdMsg.getQueueId()),
+                recordDeserializer.getBytesMessages().get(2).getProperty("__queue_id__"));
+        assertEquals(
+                String.valueOf(thirdMsg.getQueueOffset()),
+                recordDeserializer.getBytesMessages().get(2).getProperty("__queue_offset__"));
+    }
 }
diff --git a/src/test/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactoryTest.java b/src/test/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactoryTest.java
index 358e816..5edf02f 100644
--- a/src/test/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactoryTest.java
+++ b/src/test/java/org/apache/rocketmq/flink/source/table/RocketMQDynamicTableSourceFactoryTest.java
@@ -56,6 +56,8 @@ public class RocketMQDynamicTableSourceFactoryTest {
     private static final String CONSUMER_GROUP = "test_consumer";
     private static final String NAME_SERVER_ADDRESS = "127.0.0.1:9876";
 
+    private static final String FORMAT_JSON = "json";
+
     @Test
     public void testRocketMQDynamicTableSourceWithLegalOption() {
         final Map<String, String> options = new HashMap<>();
@@ -68,6 +70,40 @@ public void testRocketMQDynamicTableSourceWithLegalOption() {
         assertEquals(RocketMQScanTableSource.class.getName(), tableSource.asSummaryString());
     }
 
+    @Test
+    public void testRocketMQDynamicTableSourceWithFormatOption() {
+        final Map<String, String> options = new HashMap<>();
+        options.put("connector", IDENTIFIER);
+        options.put(RocketMQOptions.TOPIC.key(), TOPIC);
+        options.put(RocketMQOptions.CONSUMER_GROUP.key(), CONSUMER_GROUP);
+        options.put(RocketMQOptions.NAME_SERVER_ADDRESS.key(), NAME_SERVER_ADDRESS);
+
+        options.put(FactoryUtil.FORMAT.key(), FORMAT_JSON);
+        final DynamicTableSource tableSource = createTableSource(options);
+
+        assertTrue(tableSource instanceof RocketMQScanTableSource);
+        assertEquals(RocketMQScanTableSource.class.getName(), tableSource.asSummaryString());
+    }
+
+    @Test
+    public void testRocketMQDynamicTableSourceWithJsonOption() {
+        final Map<String, String> options = new HashMap<>();
+        options.put("connector", IDENTIFIER);
+        options.put(RocketMQOptions.TOPIC.key(), TOPIC);
+        options.put(RocketMQOptions.CONSUMER_GROUP.key(), CONSUMER_GROUP);
+        options.put(RocketMQOptions.NAME_SERVER_ADDRESS.key(), NAME_SERVER_ADDRESS);
+        options.put(FactoryUtil.FORMAT.key(), FORMAT_JSON);
+
+        // json props
+        options.put("json.fail-on-missing-field", "false");
+        options.put("json.ignore-parse-errors", "true");
+        options.put("json.map-null-key.mode", "FAIL");
+        final DynamicTableSource tableSource = createTableSource(options);
+
+        assertTrue(tableSource instanceof RocketMQScanTableSource);
+        assertEquals(RocketMQScanTableSource.class.getName(), tableSource.asSummaryString());
+    }
+
     @Test(expected = ValidationException.class)
     public void testRocketMQDynamicTableSourceWithoutRequiredOption() {
         final Map<String, String> options = new HashMap<>();
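Reviewer note, example usage (illustrative only, not part of the patch): a minimal sketch of how the new format options surface to a user once this change is applied. It assumes the option keys defined in RocketMQOptions in this repository ('topic', 'consumerGroup', 'nameServerAddress'); the class name, topic, group, and name-server address below are hypothetical placeholders. Only the value side is configured, because RowDeserializationSchema in this patch still rejects a configured 'key.format' with "keyDeserialization not support yet".

import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.TableEnvironment;

public class RocketMQJsonFormatExample {

    public static void main(String[] args) {
        // Plain Table API bootstrap; any streaming TableEnvironment would do.
        TableEnvironment tableEnv =
                TableEnvironment.create(
                        EnvironmentSettings.newInstance().inStreamingMode().build());

        // 'format' = 'json' lets the factory discover the JSON format via
        // getValueDecodingFormat above ('value.format' is the equivalent spelling),
        // and 'json.*' options are forwarded to that format, as in the tests.
        tableEnv.executeSql(
                "CREATE TABLE rocketmq_source (\n"
                        + "  id VARCHAR\n"
                        + ") WITH (\n"
                        + "  'connector' = 'rocketmq',\n"
                        + "  'topic' = 'example-topic',\n"
                        + "  'consumerGroup' = 'example-consumer-group',\n"
                        + "  'nameServerAddress' = '127.0.0.1:9876',\n"
                        + "  'format' = 'json',\n"
                        + "  'json.ignore-parse-errors' = 'true'\n"
                        + ")");

        // Read back the JSON-decoded rows.
        tableEnv.executeSql("SELECT id FROM rocketmq_source").print();
    }
}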