Skip to content

Commit

Permalink
fix: support UDAFs with different intermediate schema (#3412)
Browse files Browse the repository at this point in the history
  • Loading branch information
big-andy-coates authored Sep 27, 2019
1 parent 3eb5327 commit 70e10e9
Show file tree
Hide file tree
Showing 22 changed files with 307 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package io.confluent.ksql.configdef;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.util.function.Function;
import org.apache.kafka.common.config.ConfigDef.Validator;
Expand All @@ -27,9 +29,6 @@
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@RunWith(MockitoJUnitRunner.class)
public class ConfigValidatorsTest {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.materialization;

import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.Immutable;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import java.util.List;


@Immutable
public final class AggregatesInfo {

private final int startingColumnIndex;
private final List<FunctionCall> aggregateFunctions;
private final LogicalSchema schema;

/**
* @param startingColumnIndex column index of first aggregate function.
* @param aggregateFunctions the map of column index to aggregate function.
* @param schema the schema required by the aggregators.
* @return the immutable instance.
*/
public static AggregatesInfo of(
final int startingColumnIndex,
final List<FunctionCall> aggregateFunctions,
final LogicalSchema schema
) {
return new AggregatesInfo(startingColumnIndex, aggregateFunctions, schema);
}

private AggregatesInfo(
final int startingColumnIndex,
final List<FunctionCall> aggregateFunctions,
final LogicalSchema prepareSchema
) {
this.startingColumnIndex = startingColumnIndex;
this.aggregateFunctions = ImmutableList
.copyOf(requireNonNull(aggregateFunctions, "aggregateFunctions"));
this.schema = requireNonNull(prepareSchema, "prepareSchema");
}

public int startingColumnIndex() {
return startingColumnIndex;
}

public List<FunctionCall> aggregateFunctions() {
return aggregateFunctions;
}

public LogicalSchema schema() {
return schema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,48 @@
/**
* {@link Materialization} implementation responsible for handling HAVING and SELECT clauses.
*
* <p>Underlying {@link Materialization} store data in a different schema and have not had any
* HAVING predicate applied. Mapping from the aggregate store schema to the table's schema and
* applying any HAVING predicate is handled by this class.
* <p>Underlying {@link Materialization} store data is not the same as the table it servers.
* Specifically, it has not had:
* <ol>
* <li>
* The {@link io.confluent.ksql.function.udaf.Udaf#map} call applied to convert intermediate
* aggregate types on output types
* </li>
* <li>
* Any HAVING predicate applied.
* </li>
* <li>
* The select value mapper applied to convert from the internal schema to the table's scheam.
* </li>
* </ol>
*
* <p>This class is responsible for this for now. Long term, these should be handled by physical
* plan steps.
*/
class KsqlMaterialization implements Materialization {

private final Materialization inner;
private final Function<GenericRow, GenericRow> aggregateTransform;
private final Predicate<Struct, GenericRow> havingPredicate;
private final Function<GenericRow, GenericRow> storeToTableTransform;
private final LogicalSchema schema;

/**
* @param inner the inner materialization, e.g. a KS specific one
* @param aggregateTransform converts from aggregates from intermediate to output types.
* @param havingPredicate the predicate for handling HAVING clauses.
* @param storeToTableTransform maps from internal to table schema.
* @param schema the schema of the materialized table.
*/
KsqlMaterialization(
final Materialization inner,
final Function<GenericRow, GenericRow> aggregateTransform,
final Predicate<Struct, GenericRow> havingPredicate,
final Function<GenericRow, GenericRow> storeToTableTransform,
final LogicalSchema schema
) {
this.inner = requireNonNull(inner, "table");
this.aggregateTransform = requireNonNull(aggregateTransform, "aggregateTransform");
this.havingPredicate = requireNonNull(havingPredicate, "havingPredicate");
this.storeToTableTransform = requireNonNull(storeToTableTransform, "storeToTableTransform");
this.schema = requireNonNull(schema, "schema");
Expand Down Expand Up @@ -86,6 +110,9 @@ private Optional<GenericRow> filterAndTransform(
final GenericRow value
) {
return Optional.of(value)
// Call Udaf.map() to convert the internal representation stored in the state store into
// the output type of the aggregator
.map(aggregateTransform)
// HAVING predicate from source table query that has not already been applied to the
// store, so must be applied to any result from the store.
.filter(v -> havingPredicate.test(key, v))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.sqlpredicate.SqlPredicate;
import io.confluent.ksql.execution.streams.AggregateParams;
import io.confluent.ksql.execution.streams.SelectValueMapperFactory;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
Expand All @@ -47,8 +48,9 @@ public final class KsqlMaterializationFactory {
private final KsqlConfig ksqlConfig;
private final FunctionRegistry functionRegistry;
private final ProcessingLogContext processingLogContext;
private final AggregateMapperFactory aggregateMapperFactory;
private final SqlPredicateFactory sqlPredicateFactory;
private final ValueMapperFactory valueMapperFactory;
private final SelectMapperFactory selectMapperFactory;
private final MaterializationFactory materializationFactory;

public KsqlMaterializationFactory(
Expand All @@ -60,6 +62,7 @@ public KsqlMaterializationFactory(
ksqlConfig,
functionRegistry,
processingLogContext,
defaultAggregateMapperFactory(),
SqlPredicate::new,
defaultValueMapperFactory(),
KsqlMaterialization::new
Expand All @@ -71,15 +74,17 @@ public KsqlMaterializationFactory(
final KsqlConfig ksqlConfig,
final FunctionRegistry functionRegistry,
final ProcessingLogContext processingLogContext,
final AggregateMapperFactory aggregateMapperFactory,
final SqlPredicateFactory sqlPredicateFactory,
final ValueMapperFactory valueMapperFactory,
final SelectMapperFactory selectMapperFactory,
final MaterializationFactory materializationFactory
) {
this.ksqlConfig = requireNonNull(ksqlConfig, "ksqlConfig");
this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry");
this.processingLogContext = requireNonNull(processingLogContext, "processingLogContext");
this.aggregateMapperFactory = requireNonNull(aggregateMapperFactory, "aggregateMapperFactory");
this.sqlPredicateFactory = requireNonNull(sqlPredicateFactory, "sqlPredicateFactory");
this.valueMapperFactory = requireNonNull(valueMapperFactory, "valueMapperFactory");
this.selectMapperFactory = requireNonNull(selectMapperFactory, "selectMapperFactory");
this.materializationFactory = requireNonNull(materializationFactory, "materializationFactory");
}

Expand All @@ -88,6 +93,9 @@ public Materialization create(
final MaterializationInfo info,
final QueryContext.Stacker contextStacker
) {
final Function<GenericRow, GenericRow> aggregateMapper =
bakeAggregateMapper(info);

final Predicate<Struct, GenericRow> havingPredicate =
bakeHavingExpression(info, contextStacker);

Expand All @@ -96,12 +104,22 @@ public Materialization create(

return materializationFactory.create(
delegate,
aggregateMapper,
havingPredicate,
valueMapper,
info.tableSchema()
);
}

private Function<GenericRow, GenericRow> bakeAggregateMapper(
final MaterializationInfo info
) {
return aggregateMapperFactory.create(
info.aggregatesInfo(),
functionRegistry
);
}

private Predicate<Struct, GenericRow> bakeHavingExpression(
final MaterializationInfo info,
final QueryContext.Stacker contextStacker
Expand Down Expand Up @@ -135,7 +153,7 @@ private Function<GenericRow, GenericRow> bakeStoreSelects(
QueryLoggerUtil.queryLoggerName(contextStacker.push(PROJECT_OP_NAME).getQueryContext())
);

return valueMapperFactory.create(
return selectMapperFactory.create(
info.tableSelects(),
info.aggregationSchema(),
ksqlConfig,
Expand All @@ -144,7 +162,19 @@ private Function<GenericRow, GenericRow> bakeStoreSelects(
);
}

private static ValueMapperFactory defaultValueMapperFactory() {
private static AggregateMapperFactory defaultAggregateMapperFactory() {
return (info, functionRegistry) ->
new AggregateParams(
info.schema(),
info.startingColumnIndex(),
functionRegistry,
info.aggregateFunctions()
)
.getAggregator()
.getResultMapper()::apply;
}

private static SelectMapperFactory defaultValueMapperFactory() {
return (selectExpressions, sourceSchema, ksqlConfig, functionRegistry, processingLogger) ->
SelectValueMapperFactory.create(
selectExpressions,
Expand All @@ -155,6 +185,14 @@ private static ValueMapperFactory defaultValueMapperFactory() {
)::apply;
}

interface AggregateMapperFactory {

Function<GenericRow, GenericRow> create(
AggregatesInfo info,
FunctionRegistry functionRegistry
);
}

interface SqlPredicateFactory {

SqlPredicate create(
Expand All @@ -166,7 +204,7 @@ SqlPredicate create(
);
}

interface ValueMapperFactory {
interface SelectMapperFactory {

Function<GenericRow, GenericRow> create(
List<SelectExpression> selectExpressions,
Expand All @@ -181,6 +219,7 @@ interface MaterializationFactory {

KsqlMaterialization create(
Materialization inner,
Function<GenericRow, GenericRow> aggregateTransform,
Predicate<Struct, GenericRow> havingPredicate,
Function<GenericRow, GenericRow> storeToTableTransform,
LogicalSchema schema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
public final class MaterializationInfo {

private final String stateStoreName;
private final AggregatesInfo aggregatesInfo;
private final LogicalSchema aggregationSchema;
private final Optional<Expression> havingExpression;
private final LogicalSchema tableSchema;
Expand All @@ -41,22 +42,25 @@ public final class MaterializationInfo {
* Create instance.
*
* @param stateStoreName the name of the state store
* @param stateStoreSchema the schema of the state store
* @param aggregatesInfo info about the aggregate functions used.
* @param aggregationSchema the schema of the state store
* @param havingExpression optional HAVING expression that should be apply to any store result.
* @param tableSchema the schema of the table.
* @param tableSelects SELECT expressions to convert state store schema to table schema.
* @return instance.
*/
public static MaterializationInfo of(
final String stateStoreName,
final LogicalSchema stateStoreSchema,
final AggregatesInfo aggregatesInfo,
final LogicalSchema aggregationSchema,
final Optional<Expression> havingExpression,
final LogicalSchema tableSchema,
final List<SelectExpression> tableSelects
) {
return new MaterializationInfo(
stateStoreName,
stateStoreSchema,
aggregatesInfo,
aggregationSchema,
havingExpression,
tableSchema,
tableSelects
Expand All @@ -67,6 +71,10 @@ public String stateStoreName() {
return stateStoreName;
}

public AggregatesInfo aggregatesInfo() {
return aggregatesInfo;
}

public LogicalSchema aggregationSchema() {
return aggregationSchema;
}
Expand All @@ -85,12 +93,14 @@ public List<SelectExpression> tableSelects() {

private MaterializationInfo(
final String stateStoreName,
final AggregatesInfo aggregatesInfo,
final LogicalSchema aggregationSchema,
final Optional<Expression> havingExpression,
final LogicalSchema tableSchema,
final List<SelectExpression> tableSelects
) {
this.stateStoreName = requireNonNull(stateStoreName, "stateStoreName");
this.aggregatesInfo = requireNonNull(aggregatesInfo, "aggregatesInfo");
this.aggregationSchema = requireNonNull(aggregationSchema, "aggregationSchema");
this.havingExpression = requireNonNull(havingExpression, "havingExpression");
this.tableSchema = requireNonNull(tableSchema, "tableSchema");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.materialization.AggregatesInfo;
import io.confluent.ksql.materialization.MaterializationInfo;
import io.confluent.ksql.metastore.model.KeyField;
import io.confluent.ksql.name.ColumnName;
Expand Down Expand Up @@ -248,6 +249,8 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {

final QueryContext.Stacker aggregationContext = contextStacker.push(AGGREGATION_OP_NAME);

// This is the schema post any {@link Udaf#map} steps to reduce intermediate aggregate state
// to the final output state
final LogicalSchema outputSchema = buildLogicalSchema(
prepareSchema,
functionsWithInternalIdentifiers,
Expand Down Expand Up @@ -280,8 +283,15 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
final List<SelectExpression> finalSelects = internalSchema
.updateFinalSelectExpressions(getFinalSelectExpressions());

final AggregatesInfo aggregatesInfo = AggregatesInfo.of(
requiredColumns.size(),
functionsWithInternalIdentifiers,
prepareSchema
);

materializationInfo = Optional.of(MaterializationInfo.of(
AGGREGATE_STATE_STORE_NAME,
aggregatesInfo,
outputSchema,
havingExpression,
schema,
Expand Down
Loading

0 comments on commit 70e10e9

Please sign in to comment.