Skip to content

Commit

Permalink
feat: make UDAFs configurable and remove limit on COLLECT_LIST/SET (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra committed Jan 13, 2021
1 parent e2cd29d commit 63ae169
Show file tree
Hide file tree
Showing 33 changed files with 1,088 additions and 187 deletions.
Expand Up @@ -15,9 +15,11 @@

package io.confluent.ksql.function;

import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
Expand All @@ -28,18 +30,41 @@
*/
public class AggregateFunctionInitArguments {

public static final AggregateFunctionInitArguments EMPTY_ARGS =
new AggregateFunctionInitArguments();

private final int udafIndex;
private final List<Object> initArgs;
private final Map<String, ?> config;

public static final AggregateFunctionInitArguments EMPTY_ARGS =
new AggregateFunctionInitArguments();
/**
* This constructor should only be used for legacy "built-in" UDAF
* implementations that implement AggregateFunctionFactory directly,
* such as TopKAggregateFunctionFactory. Otherwise, the config will
* not be properly passed through to the aggregate function.
*
* <p>Delegates to the list-based constructor with an empty config map.
*
* @param index the value stored as the UDAF index
* @param initArgs the initial arguments, wrapped as a list via {@code Arrays.asList}
*/
public AggregateFunctionInitArguments(
final int index,
final Object... initArgs
) {
this(index, ImmutableMap.of(/* not a configurable function */), Arrays.asList(initArgs));
}

public AggregateFunctionInitArguments(final int index, final Object... initArgs) {
this(index, Arrays.asList(initArgs));
/**
* Varargs convenience constructor that delegates to the list-based constructor.
*
* @param index the value stored as the UDAF index
* @param config the configuration map; the list-based constructor rejects {@code null}
*     and keeps an immutable copy
* @param initArgs the initial arguments, wrapped as a list via {@code Arrays.asList}
*/
public AggregateFunctionInitArguments(
final int index,
final Map<String, ?> config,
final Object... initArgs
) {
this(index, config, Arrays.asList(initArgs));
}

public AggregateFunctionInitArguments(final int index, final List<Object> initArgs) {
public AggregateFunctionInitArguments(
final int index,
final Map<String, ?> config,
final List<Object> initArgs
) {
this.udafIndex = index;
this.config = ImmutableMap.copyOf(Objects.requireNonNull(config, "config"));
this.initArgs = Objects.requireNonNull(initArgs);

if (index < 0) {
Expand All @@ -49,6 +74,7 @@ public AggregateFunctionInitArguments(final int index, final List<Object> initAr

/**
* Builds the argument-less instance exposed as {@code EMPTY_ARGS}: UDAF index 0,
* an empty config map and an empty argument list.
*/
private AggregateFunctionInitArguments() {
this.udafIndex = 0;
this.config = ImmutableMap.of();
this.initArgs = Collections.emptyList();
}

Expand All @@ -63,4 +89,8 @@ public Object arg(final int i) {
/**
* Returns the initial arguments held by this instance.
*
* <p>The list-based constructor stores the supplied list without copying it, so the
* returned list is the same object the caller passed in.
*
* @return the list of initial arguments
*/
public List<Object> args() {
return initArgs;
}

/**
* Returns the configuration associated with these init arguments.
*
* <p>Every constructor stores an {@code ImmutableMap} (a copy of the supplied map, or an
* empty map), so the returned map cannot be modified.
*
* @return the configuration map
*/
public Map<String, ?> config() {
return config;
}
}
Expand Up @@ -509,6 +509,10 @@ String getName() {
public static final Set<String> SSL_CONFIG_NAMES = sslConfigNames();
public static final Set<String> STREAM_TOPIC_CONFIG_NAMES = streamTopicConfigNames();

/**
* Creates a {@code KsqlConfig} constructed from an empty property map.
*
* @return a new {@code KsqlConfig} with no explicitly supplied properties
*/
public static KsqlConfig empty() {
return new KsqlConfig(ImmutableMap.of());
}

/**
* Selects the config definition that matches the requested generation.
*
* @param generation the config generation to look up
* @return {@code CURRENT_DEF} for {@code ConfigGeneration.CURRENT}, otherwise {@code LEGACY_DEF}
*/
private static ConfigDef configDef(final ConfigGeneration generation) {
  if (generation == ConfigGeneration.CURRENT) {
    return CURRENT_DEF;
  }
  return LEGACY_DEF;
}
Expand Down
Expand Up @@ -100,7 +100,7 @@ private List<SqlType> buildAllParams(
allParams.add(primitiveType);
} catch (final Exception e) {
throw new KsqlFunctionException("Only primitive init arguments are supported by UDAF "
+ getName() + ", but got " + arg);
+ getName() + ", but got " + arg, e);
}
}

Expand Down
Expand Up @@ -28,6 +28,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.kafka.common.Configurable;
import org.apache.kafka.common.metrics.Metrics;

class UdafFactoryInvoker implements FunctionSignature {
Expand Down Expand Up @@ -78,6 +79,11 @@ KsqlAggregateFunction createFunction(final AggregateFunctionInitArguments initAr
final Object[] factoryArgs = initArgs.args().toArray();
try {
final Udaf udaf = (Udaf)method.invoke(null, factoryArgs);

if (udaf instanceof Configurable) {
((Configurable) udaf).configure(initArgs.config());
}

final KsqlAggregateFunction function;
if (TableUdaf.class.isAssignableFrom(method.getReturnType())) {
function = new UdafTableAggregateFunction(
Expand Down
Expand Up @@ -15,98 +15,109 @@

package io.confluent.ksql.function.udaf.array;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import io.confluent.ksql.function.udaf.TableUdaf;
import io.confluent.ksql.function.udaf.UdafDescription;
import io.confluent.ksql.function.udaf.UdafFactory;
import io.confluent.ksql.util.KsqlConstants;
import java.util.List;
import java.util.Map;
import org.apache.kafka.common.Configurable;

@UdafDescription(
name = "collect_list",
description = "Gather all of the values from an input grouping into a single Array field."
+ "\nAlthough this aggregate works on both Stream and Table inputs, the order of entries"
+ " in the result array is not guaranteed when working on Table input data."
+ "\nThis version limits the size of the resultant Array to 1000 entries, beyond which"
+ " any further values will be silently ignored.",
+ "\nYou may limit the size of the resultant Array to N entries, beyond which"
+ " any further values will be silently ignored, by setting the"
+ " ksql.functions.collect_list.limit configuration to N.",
author = KsqlConstants.CONFLUENT_AUTHOR
)
public final class CollectListUdaf {

@VisibleForTesting
static final int LIMIT = 1000;
public static final String LIMIT_CONFIG = "ksql.functions.collect_list.limit";

private CollectListUdaf() {
// just to make the checkstyle happy
}

private static <T> TableUdaf<T, List<T>, List<T>> listCollector() {
return new TableUdaf<T, List<T>, List<T>>() {

@Override
public List<T> initialize() {
return Lists.newArrayList();
}

@Override
public List<T> aggregate(final T thisValue, final List<T> aggregate) {
if (aggregate.size() < LIMIT) {
aggregate.add(thisValue);
}
return aggregate;
}

@Override
public List<T> merge(final List<T> aggOne, final List<T> aggTwo) {
final int remainingCapacity = LIMIT - aggOne.size();
aggOne.addAll(aggTwo.subList(0, Math.min(remainingCapacity, aggTwo.size())));
return aggOne;
}

@Override
public List<T> map(final List<T> agg) {
return agg;
}

@Override
public List<T> undo(final T valueToUndo, final List<T> aggregateValue) {
// A more ideal solution would remove the value which corresponded to the original insertion
// but keeping track of that is more complex so we just remove the last value for now.
final int lastIndex = aggregateValue.lastIndexOf(valueToUndo);
// If we cannot find the value, that means that we hit the limit and never inserted it, so
// just return.
if (lastIndex < 0) {
return aggregateValue;
}
aggregateValue.remove(lastIndex);
return aggregateValue;
}
};
}

@UdafFactory(description = "collect values of a Bigint field into a single Array")
public static TableUdaf<Long, List<Long>, List<Long>> createCollectListLong() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of an Integer field into a single Array")
public static TableUdaf<Integer, List<Integer>, List<Integer>> createCollectListInt() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of a Double field into a single Array")
public static TableUdaf<Double, List<Double>, List<Double>> createCollectListDouble() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of a String/Varchar field into a single Array")
public static TableUdaf<String, List<String>, List<String>> createCollectListString() {
return listCollector();
return new Collect<>();
}

@UdafFactory(description = "collect values of a Boolean field into a single Array")
public static TableUdaf<Boolean, List<Boolean>, List<Boolean>> createCollectListBool() {
return listCollector();
return new Collect<>();
}

/**
 * Aggregator that gathers the values of a grouping into a single list.
 *
 * <p>The list grows without a practical bound ({@code Integer.MAX_VALUE}) unless
 * {@link #configure(Map)} supplies a non-negative limit under {@code LIMIT_CONFIG}; once the
 * list holds {@code limit} entries, further values are silently ignored.
 *
 * @param <T> the type of the collected values
 */
private static final class Collect<T> implements TableUdaf<T, List<T>, List<T>>, Configurable {

  private int limit = Integer.MAX_VALUE;

  /**
   * Reads the optional size limit from the supplied configuration.
   *
   * <p>The value may be a {@link Number} or a {@link String} holding an integer; the latter
   * is accepted because configuration values that originate from property files or client
   * overrides may not have been converted to a numeric type. A missing value keeps the current
   * limit, and a negative value means "no limit".
   *
   * @param map the configuration to read {@code LIMIT_CONFIG} from
   * @throws NumberFormatException if the limit is a string that is not a valid integer
   * @throws ClassCastException if the limit is neither a string nor a number
   */
  @Override
  public void configure(final Map<String, ?> map) {
    final Object configured = map.get(LIMIT_CONFIG);
    if (configured instanceof String) {
      this.limit = Integer.parseInt(((String) configured).trim());
    } else if (configured != null) {
      this.limit = ((Number) configured).intValue();
    }

    // A negative limit disables the cap.
    if (this.limit < 0) {
      this.limit = Integer.MAX_VALUE;
    }
  }

  /**
   * Creates the initial, empty aggregate.
   *
   * @return a new mutable empty list
   */
  @Override
  public List<T> initialize() {
    return Lists.newArrayList();
  }

  /**
   * Appends a value to the aggregate unless the limit has already been reached.
   *
   * @param thisValue the value to collect
   * @param aggregate the aggregate so far; modified in place
   * @return the same {@code aggregate} instance
   */
  @Override
  public List<T> aggregate(final T thisValue, final List<T> aggregate) {
    if (aggregate.size() < limit) {
      aggregate.add(thisValue);
    }
    return aggregate;
  }

  /**
   * Appends as many leading entries of {@code aggTwo} to {@code aggOne} as the limit allows.
   *
   * @param aggOne the aggregate to extend; modified in place
   * @param aggTwo the aggregate whose leading entries are copied
   * @return the same {@code aggOne} instance
   */
  @Override
  public List<T> merge(final List<T> aggOne, final List<T> aggTwo) {
    // aggOne may already exceed the limit (e.g. it was built under a larger limit), in which
    // case the difference is negative; clamp to zero so subList is never given a negative
    // upper bound.
    final int remainingCapacity = Math.max(0, limit - aggOne.size());
    aggOne.addAll(aggTwo.subList(0, Math.min(remainingCapacity, aggTwo.size())));
    return aggOne;
  }

  /**
   * Returns the aggregate unchanged as the function result.
   *
   * @param agg the aggregate
   * @return the same {@code agg} instance
   */
  @Override
  public List<T> map(final List<T> agg) {
    return agg;
  }

  /**
   * Removes the last occurrence of a value from the aggregate, if present.
   *
   * @param valueToUndo the value to remove
   * @param aggregateValue the aggregate so far; modified in place
   * @return the same {@code aggregateValue} instance
   */
  @Override
  public List<T> undo(final T valueToUndo, final List<T> aggregateValue) {
    // A more ideal solution would remove the value which corresponded to the original insertion
    // but keeping track of that is more complex so we just remove the last value for now.
    final int lastIndex = aggregateValue.lastIndexOf(valueToUndo);
    // If we cannot find the value, that means that we hit the limit and never inserted it, so
    // just return.
    if (lastIndex < 0) {
      return aggregateValue;
    }
    aggregateValue.remove(lastIndex);
    return aggregateValue;
  }
}
}

0 comments on commit 63ae169

Please sign in to comment.