Skip to content

Commit

Permalink
feat: Support Object as the type parameter for a UDAF variadic column argument (#9481)
Browse files Browse the repository at this point in the history

* feat: support Object as type param for UDAF variadic col args

* test: add unit tests for heterogeneous col var args

* test: add QTTs for heterogeneous col varargs

* test: add historical plans

* Update non-var obj UDAF param error message per code review

Co-authored-by: Zara Lim <jzaralim@gmail.com>

* test: update error messages in tests

* refactor: move obj var arg param type to static const

Co-authored-by: Zara Lim <jzaralim@gmail.com>
  • Loading branch information
reneesoika and jzaralim committed Aug 24, 2022
1 parent 6a6ccd3 commit 9444fe1
Show file tree
Hide file tree
Showing 30 changed files with 4,501 additions and 43 deletions.
22 changes: 22 additions & 0 deletions ksqldb-cli/src/test/java/io/confluent/ksql/cli/CliTest.java
Expand Up @@ -1109,6 +1109,28 @@ public void shouldDescribeVariadicAggregateFunction() {
assertThat(output, containsString(expectedVariant));
}

/**
 * Runs {@code describe function obj_col_arg;} through the CLI and checks that the captured
 * terminal output contains both the function summary and the single expected variation,
 * whose variadic column argument is printed as {@code ANY[]}.
 */
@Test
public void shouldDescribeVariadicObjectAggregateFunction() {
  // When:
  localCli.handleLine("describe function obj_col_arg;");
  final String printed = terminal.getOutputString();

  // Then: the header section of the description is present...
  final String summarySection =
      "Name : OBJ_COL_ARG\n"
          + "Author : Confluent\n"
          + "Overview : Returns an array of rows where all the given columns are non-null.\n"
          + "Type : AGGREGATE\n"
          + "Jar : internal\n"
          + "Variations : \n";
  assertThat(printed, containsString(summarySection));

  // ...and so is the variation that takes the object-typed variadic column.
  final String variationSection =
      "\tVariation : OBJ_COL_ARG(val1 INT, val2 ANY[])\n"
          + "\tReturns : ARRAY<INT>\n"
          + "\tDescription : Testing factory";
  assertThat(printed, containsString(variationSection));
}

@Test
public void shouldDescribeTableFunction() {
final String expectedOutput =
Expand Down
Expand Up @@ -59,7 +59,10 @@
* have variable arguments, return the method with the more non-variable
* arguments.</li>
* <li>If two methods exist that match given the above rules, return the
* method with fewer generic arguments.</li>
* method with fewer generic arguments, including any
* {@code Object...} arguments.</li>
* <li>If two methods exist that match given the above rules, return the
* method without an {@code Object...} argument.</li>
* <li>If two methods exist that match given the above rules, return the
* method with the variadic argument in the later position.</li>
* <li>If two methods exist that match given the above rules, the function
Expand Down Expand Up @@ -102,6 +105,7 @@ public class UdfIndex<T extends FunctionSignature> {
// Gist: https://gist.github.com/reneesoika/2ec934940c98dad6b0f68c89769fbe42

private static final Logger LOG = LoggerFactory.getLogger(UdfIndex.class);
private static final ParamType OBJ_VAR_ARG = ArrayType.of(ParamTypes.ANY);

private final String udfName;
private final Node root = new Node();
Expand Down Expand Up @@ -494,15 +498,19 @@ private static <T extends FunctionSignature> String formatAvailableSignatures(fi
}

private final class Node {

@VisibleForTesting
private final Comparator<T> compareFunctions =
Comparator.nullsFirst(
Comparator
.<T, Integer>comparing(fun -> fun.isVariadic() ? 0 : 1)
.thenComparing(fun -> fun.parameters().size())
.thenComparing(fun -> -countGenerics(fun))
.thenComparing(this::indexOfVariadic)
.thenComparing(fun ->
fun.parameters().stream().anyMatch(
(param) -> param.equals(OBJ_VAR_ARG)
)
? 0 : 1
).thenComparing(this::indexOfVariadic)
);

private final Map<Pair<Parameter, Parameter>, Node> children;
Expand Down Expand Up @@ -545,7 +553,8 @@ int compare(final Node other) {

private int countGenerics(final T function) {
return function.parameters().stream()
.filter(GenericsUtil::hasGenerics)
.filter((param) -> GenericsUtil.hasGenerics(param)
|| param.equals(OBJ_VAR_ARG))
.mapToInt(p -> 1)
.sum();
}
Expand Down
@@ -0,0 +1,34 @@
/*
* Copyright 2022 Confluent Inc.
*
* Licensed under the Confluent Community License; you may not use this file
* except in compliance with the License. You may obtain a copy of the License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.function.types;

/**
 * Function parameter type whose textual form is {@code ANY}.
 *
 * <p>All {@code AnyType} objects compare equal to one another and share a single hash code,
 * so the type behaves as one value; {@link #INSTANCE} is the shared instance.
 */
public final class AnyType extends ObjectType {

  /** Hash code shared by every instance, keeping {@code hashCode} consistent with equals. */
  private static final int HASH = 11;

  /** Shared instance of this type. */
  public static final AnyType INSTANCE = new AnyType();

  /**
   * @return the fixed name {@code "ANY"}
   */
  @Override
  public String toString() {
    return "ANY";
  }

  /**
   * @param obj the object to compare against; may be {@code null}
   * @return {@code true} if and only if {@code obj} is an {@code AnyType}
   */
  @Override
  public boolean equals(final Object obj) {
    // The class is final, so instanceof is an exact-type check; it is false for null.
    final boolean isAnyType = obj instanceof AnyType;
    return isAnyType;
  }

  /**
   * @return the same constant for every instance
   */
  @Override
  public int hashCode() {
    return HASH;
  }
}
Expand Up @@ -45,6 +45,7 @@ private ParamTypes() {
public static final TimestampType TIMESTAMP = TimestampType.INSTANCE;
public static final IntervalUnitType INTERVALUNIT = IntervalUnitType.INSTANCE;
public static final BytesType BYTES = BytesType.INSTANCE;
public static final AnyType ANY = AnyType.INSTANCE;

public static boolean areCompatible(final SqlArgument actual, final ParamType declared) {
return areCompatible(actual, declared, false);
Expand All @@ -59,6 +60,10 @@ public static boolean areCompatible(
) {
// CHECKSTYLE_RULES.ON: CyclomaticComplexity
// CHECKSTYLE_RULES.ON: NPathComplexity
if (declared instanceof AnyType) {
return true;
}

final Optional<SqlLambda> sqlLambdaOptional = argument.getSqlLambda();

if (sqlLambdaOptional.isPresent() && declared instanceof LambdaType) {
Expand Down
Expand Up @@ -43,6 +43,8 @@ public class UdfIndexTest {
private static final ParamType STRING_VARARGS = ArrayType.of(ParamTypes.STRING);

private static final ParamType INT_VARARGS = ArrayType.of(ParamTypes.INTEGER);

private static final ParamType OBJ_VARARGS = ArrayType.of(ParamTypes.ANY);
private static final ParamType STRING = ParamTypes.STRING;
private static final ParamType DECIMAL = ParamTypes.DECIMAL;
private static final ParamType INT = ParamTypes.INTEGER;
Expand Down Expand Up @@ -1295,6 +1297,42 @@ public void shouldFindFewerGenericsWithEarlierVariadic() {
assertThat(fun.name(), equalTo(EXPECTED));
}

/**
 * Two variadic candidates differ only in their trailing variadic parameter: one takes
 * {@code INT_VARARGS}, the other {@code OBJ_VARARGS}. For four INTEGER arguments the lookup
 * is expected to resolve to the candidate without the object variadic.
 */
@Test
public void shouldFindFewerGenericsWithoutObjVariadic() {
  // Given:
  final KsqlScalarFunction intVariadic =
      function(EXPECTED, 3, INT, GenericType.of("A"), INT, INT_VARARGS);
  final KsqlScalarFunction objVariadic =
      function(OTHER, 3, INT, GenericType.of("B"), INT, OBJ_VARARGS);
  givenFunctions(intVariadic, objVariadic);

  final ImmutableList<SqlArgument> fourIntegers = ImmutableList.of(
      SqlArgument.of(INTEGER),
      SqlArgument.of(INTEGER),
      SqlArgument.of(INTEGER),
      SqlArgument.of(INTEGER)
  );

  // When:
  final KsqlScalarFunction resolved = udfIndex.getFunction(fourIntegers);

  // Then:
  assertThat(resolved.name(), equalTo(EXPECTED));
}

/**
 * Two variadic candidates differ only in their trailing variadic parameter: one takes an
 * array of generic {@code C}, the other {@code OBJ_VARARGS}. For four INTEGER arguments the
 * lookup is expected to resolve to the generic-variadic candidate.
 */
@Test
public void shouldFindPreferGenericVariadicToObjVariadic() {
  // Given:
  final KsqlScalarFunction genericVariadic =
      function(EXPECTED, 3, INT, GenericType.of("A"), INT, ArrayType.of(GenericType.of("C")));
  final KsqlScalarFunction objVariadic =
      function(OTHER, 3, INT, GenericType.of("B"), INT, OBJ_VARARGS);
  givenFunctions(genericVariadic, objVariadic);

  final ImmutableList<SqlArgument> fourIntegers = ImmutableList.of(
      SqlArgument.of(INTEGER),
      SqlArgument.of(INTEGER),
      SqlArgument.of(INTEGER),
      SqlArgument.of(INTEGER)
  );

  // When:
  final KsqlScalarFunction resolved = udfIndex.getFunction(fourIntegers);

  // Then:
  assertThat(resolved.name(), equalTo(EXPECTED));
}

@Test
public void shouldThrowOnAmbiguousImplicitCastWithGenerics() {
// Given:
Expand Down Expand Up @@ -1417,6 +1455,26 @@ public void shouldChooseLaterVariadicWhenTwoVariadicsMatchReversedInsertionOrder
assertThat(fun.name(), equalTo(EXPECTED));
}

/**
 * Both candidates contain an {@code OBJ_VARARGS} parameter, at variadic index 1 and 2
 * respectively. For (BIGINT, INTEGER, STRING, DOUBLE) the lookup is expected to resolve to
 * the candidate whose variadic parameter sits at the later position.
 */
@Test
public void shouldChooseLaterVariadicWhenTwoObjVariadicsMatch() {
  // Given:
  final KsqlScalarFunction earlierVariadic =
      function(OTHER, 1, GenericType.of("A"), OBJ_VARARGS, STRING, DOUBLE);
  final KsqlScalarFunction laterVariadic =
      function(EXPECTED, 2, GenericType.of("B"), INT, OBJ_VARARGS, DOUBLE);
  givenFunctions(earlierVariadic, laterVariadic);

  final ImmutableList<SqlArgument> callArguments = ImmutableList.of(
      SqlArgument.of(SqlTypes.BIGINT),
      SqlArgument.of(SqlTypes.INTEGER),
      SqlArgument.of(SqlTypes.STRING),
      SqlArgument.of(SqlTypes.DOUBLE)
  );

  // When:
  final KsqlScalarFunction resolved = udfIndex.getFunction(callArguments);

  // Then:
  assertThat(resolved.name(), equalTo(EXPECTED));
}

/**
 * Registers each of the given functions with the index under test, in argument order.
 *
 * @param functions the functions to add to {@code udfIndex}
 */
private void givenFunctions(final KsqlScalarFunction... functions) {
  for (final KsqlScalarFunction candidate : functions) {
    udfIndex.addFunction(candidate);
  }
}
Expand Down
Expand Up @@ -123,7 +123,7 @@ class UdafTypes {
throw new KsqlException("A UDAF and its factory can have at most one variadic argument");
}

variadicColIndex = indexOfVariadic(inputTypes);
variadicColIndex = indexOfType(inputTypes, VARIADIC_TYPE);
if (method.isVarArgs()) {
this.isVariadic = true;
} else if (isMultipleArgs && variadicColIndex > -1) {
Expand All @@ -140,15 +140,23 @@ class UdafTypes {
this.isVariadic = false;
}

// Disallow an object type outside a variadic column
final boolean hasMultipleObjectArgs = countTypes(inputTypes, Object.class) > 1;
final int indexOfFirstObj = indexOfType(inputTypes, Object.class);
final boolean objArgIsNotVariadic = indexOfFirstObj >= 0 && indexOfFirstObj != variadicColIndex;
if (hasMultipleObjectArgs || objArgIsNotVariadic) {
throw new KsqlException("The Object type can only be used as a variadic column argument");
}

this.aggregateType = type.getActualTypeArguments()[1];
this.outputType = type.getActualTypeArguments()[2];

this.literalParams = FunctionLoaderUtils
.createParameters(method, functionName.text(), sqlTypeParser);

validateTypes(inputTypes);
validateType(aggregateType);
validateType(outputType);
validateTypes(inputTypes, variadicColIndex);
validateType(aggregateType, false);
validateType(outputType, false);
}

List<ParameterInfo> getInputSchema(final String[] inSchemas) {
Expand All @@ -161,9 +169,11 @@ List<ParameterInfo> getInputSchema(final String[] inSchemas) {

validateStructAnnotation(inputType, schema, "paramSchema");

ParamType paramType = getSchemaFromType(inputType, schema);
final ParamType paramType;
if (paramIndex == variadicColIndex) {
paramType = ArrayType.of(paramType);
paramType = ArrayType.of(getSchemaFromType(inputType, schema, true));
} else {
paramType = getSchemaFromType(inputType, schema, false);
}

paramTypes.add(paramType);
Expand All @@ -189,12 +199,12 @@ List<ParameterInfo> getInputSchema(final String[] inSchemas) {

ParamType getAggregateSchema(final String aggSchema) {
validateStructAnnotation(aggregateType, aggSchema, "aggregateSchema");
return getSchemaFromType(aggregateType, aggSchema);
return getSchemaFromType(aggregateType, aggSchema, false);
}

ParamType getOutputSchema(final String outSchema) {
validateStructAnnotation(outputType, outSchema, "returnSchema");
return getSchemaFromType(outputType, outSchema);
return getSchemaFromType(outputType, outSchema, false);
}

boolean isVariadic() {
Expand All @@ -205,20 +215,22 @@ List<ParameterInfo> literalParams() {
return ImmutableList.copyOf(literalParams);
}

private void validateType(final Type t) {
if (!(t instanceof TypeVariable) && isUnsupportedType((Class<?>) getRawType(t))) {
/**
 * Rejects a type that the support check reports as unsupported.
 *
 * <p>Type variables are never checked. Any other type is reduced to its raw class and passed
 * to {@code isUnsupportedType} together with {@code isVariadic}.
 *
 * @param t the type to check
 * @param isVariadic forwarded to the support check as its second argument
 * @throws KsqlException if the raw class of {@code t} is reported as unsupported
 */
private void validateType(final Type t, final boolean isVariadic) {
  if (t instanceof TypeVariable) {
    return;
  }

  final Class<?> rawClass = (Class<?>) getRawType(t);
  if (isUnsupportedType(rawClass, isVariadic)) {
    throw new KsqlException(String.format(invalidClassErrorMsg, t));
  }
}

private void validateTypes(final Type[] types) {
for (Type type : types) {
validateType(type);
/**
 * Validates every type in {@code types}, in order, flagging the one at
 * {@code variadicColIndex} as the variadic position.
 *
 * @param types the types to validate
 * @param variadicColIndex position whose type is validated with the variadic flag set; an
 *     index that matches no position leaves the flag unset for every type
 */
private void validateTypes(final Type[] types, final int variadicColIndex) {
  int position = 0;
  for (final Type candidate : types) {
    validateType(candidate, position == variadicColIndex);
    position++;
  }
}

private static long countVariadic(final Type[] types, final Method factory) {
long count = Arrays.stream(types).filter((type) -> getRawType(type) == VARIADIC_TYPE).count();
long count = countTypes(types, VARIADIC_TYPE);

/* If there is a variadic initial argument, include it in the total number of variadic
arguments. We need to include this because there can only be one variadic argument in
Expand All @@ -230,9 +242,13 @@ private static long countVariadic(final Type[] types, final Method factory) {
return count;
}

private static int indexOfVariadic(final Type[] types) {
/**
 * Counts the entries of {@code types} whose raw type equals {@code matchType}.
 *
 * @param types the types to inspect
 * @param matchType the raw type to look for
 * @return the number of matching entries
 */
private static long countTypes(final Type[] types, final Type matchType) {
  long matches = 0;
  for (final Type candidate : types) {
    if (getRawType(candidate).equals(matchType)) {
      matches++;
    }
  }
  return matches;
}

private static int indexOfType(final Type[] types, final Type matchType) {
return IntStream.range(0, types.length)
.filter((index) -> getRawType(types[index]) == VARIADIC_TYPE)
.filter((index) -> getRawType(types[index]).equals(matchType))
.findFirst()
.orElse(-1);
}
Expand All @@ -248,11 +264,15 @@ private static void validateStructAnnotation(
}
}

private ParamType getSchemaFromType(final Type type, final String schema) {
return schema.isEmpty()
? UdfUtil.getSchemaFromType(type)
: SchemaConverters.sqlToFunctionConverter().toFunctionType(
sqlTypeParser.parse(schema).getSqlType());
/**
 * Resolves the parameter type for a Java type and an optional schema string.
 *
 * <p>A non-empty {@code schema} always wins: it is parsed and converted to a function type.
 * Only when the schema is empty is the type derived from {@code type}, using the var-args
 * derivation if {@code isVariadic} is set and the regular derivation otherwise.
 *
 * @param type the Java type, used only when {@code schema} is empty
 * @param schema the schema string; may be empty
 * @param isVariadic selects the var-args derivation when {@code schema} is empty
 * @return the resolved parameter type
 */
private ParamType getSchemaFromType(final Type type, final String schema,
    final boolean isVariadic) {
  if (!schema.isEmpty()) {
    return SchemaConverters.sqlToFunctionConverter().toFunctionType(
        sqlTypeParser.parse(schema).getSqlType()
    );
  }

  if (isVariadic) {
    return UdfUtil.getVarArgsSchemaFromType(type);
  }
  return UdfUtil.getSchemaFromType(type);
}

private static Type getRawType(final Type type) {
Expand All @@ -262,14 +282,16 @@ private static Type getRawType(final Type type) {
return type;
}

static boolean isUnsupportedType(final Class<?> type) {
@SuppressWarnings("checkstyle:BooleanExpressionComplexity")
private static boolean isUnsupportedType(final Class<?> type, final boolean allowObject) {
return !SUPPORTED_TYPES.contains(type)
&& (!type.isArray() || !SUPPORTED_TYPES.contains(type.getComponentType()))
&& SUPPORTED_TYPES.stream().noneMatch(supported -> supported.isAssignableFrom(type));
&& SUPPORTED_TYPES.stream().noneMatch(supported -> supported.isAssignableFrom(type))
&& (!allowObject || !type.equals(Object.class));
}

static void checkSupportedType(final Method method, final Class<?> type) {
if (UdafTypes.isUnsupportedType(type)) {
if (UdafTypes.isUnsupportedType(type, false)) {
throw new KsqlException(
String.format(
"Type %s is not supported by UDF methods. "
Expand Down

0 comments on commit 9444fe1

Please sign in to comment.