Skip to content

Commit

Permalink
fix: fixed nondeterminism in UdfIndex (#7719)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: Existing queries that relied on vague implicit casting will not be started after an upgrade, and new queries that rely on vague implicit casting will be rejected. For example, foo(INT, INT) will not be able to resolve against two underlying function signatures of foo(BIGINT, BIGINT) and foo(DOUBLE, DOUBLE). Calling a function whose only parameter is variadic with an explicit null will also result in the call being rejected as vague.
  • Loading branch information
Sullivan-Patrick committed Jun 28, 2021
1 parent e78a83b commit cd1a988
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 62 deletions.
108 changes: 75 additions & 33 deletions ksqldb-common/src/main/java/io/confluent/ksql/function/UdfIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@
* <li>If two methods exist that match given the above rules, and both
* have variable arguments, return the method with the more non-variable
* arguments.</li>
* <li>If two methods exist that match only null values, return the one
* that was added first.</li>
* <li>If two methods exist that match given the above rules, return the
* method with fewer generic arguments.</li>
* <li>If two methods exist that match given the above rules, the function
* call is ambiguous and an exception is thrown.</li>
* </ul>
*/
public class UdfIndex<T extends FunctionSignature> {
Expand Down Expand Up @@ -112,7 +114,6 @@ void addFunction(final T function) {
);
}

final int order = allFunctions.size();

Node curr = root;
Node parent = curr;
Expand All @@ -125,48 +126,61 @@ void addFunction(final T function) {
if (function.isVariadic()) {
// first add the function to the parent to address the
// case of empty varargs
parent.update(function, order);
parent.update(function);

// then add a new child node with the parameter value type
// and add this function to that node
final ParamType varargSchema = Iterables.getLast(parameters);
final Parameter vararg = new Parameter(varargSchema, true);
final Node leaf = parent.children.computeIfAbsent(vararg, ignored -> new Node());
leaf.update(function, order);
leaf.update(function);

// add a self referential loop for varargs so that we can
// add as many of the same param at the end and still retrieve
// this node
leaf.children.putIfAbsent(vararg, leaf);
}

curr.update(function, order);
curr.update(function);
}

T getFunction(final List<SqlArgument> arguments) {
final List<Node> candidates = new ArrayList<>();

// first try to get the candidates without any implicit casting
getCandidates(arguments, 0, root, candidates, new HashMap<>(), false);
final Optional<T> fun = candidates
.stream()
.max(Node::compare)
.map(node -> node.value);

if (fun.isPresent()) {
return fun.get();
Optional<T> candidate = findMatchingCandidate(arguments, false);
if (candidate.isPresent()) {
return candidate.get();
} else if (!supportsImplicitCasts) {
throw createNoMatchingFunctionException(arguments);
}

// if none were found (candidates is empty) try again with
// implicit casting
getCandidates(arguments, 0, root, candidates, new HashMap<>(), true);
return candidates
.stream()
.max(Node::compare)
.map(node -> node.value)
.orElseThrow(() -> createNoMatchingFunctionException(arguments));
// if none were found (candidate isn't present) try again with implicit casting
candidate = findMatchingCandidate(arguments, true);
if (candidate.isPresent()) {
return candidate.get();
}
throw createNoMatchingFunctionException(arguments);
}

private Optional<T> findMatchingCandidate(
final List<SqlArgument> arguments, final boolean allowCasts) {

final List<Node> candidates = new ArrayList<>();

getCandidates(arguments, 0, root, candidates, new HashMap<>(), allowCasts);
candidates.sort(Node::compare);

final int len = candidates.size();
if (len == 1) {
return Optional.of(candidates.get(0).value);
} else if (len > 1) {
if (candidates.get(len - 1).compare(candidates.get(len - 2)) > 0) {
return Optional.of(candidates.get(len - 1).value);
}
throw createVagueImplicitCastException(arguments);
}

return Optional.empty();
}

private void getCandidates(
Expand Down Expand Up @@ -194,10 +208,8 @@ private void getCandidates(
}
}

private KsqlException createNoMatchingFunctionException(final List<SqlArgument> paramTypes) {
LOG.debug("Current UdfIndex:\n{}", describe());

final String requiredTypes = paramTypes.stream()
private String getParamsAsString(final List<SqlArgument> paramTypes) {
return paramTypes.stream()
.map(argument -> {
if (argument == null) {
return "null";
Expand All @@ -206,10 +218,35 @@ private KsqlException createNoMatchingFunctionException(final List<SqlArgument>
}
})
.collect(Collectors.joining(", ", "(", ")"));
}

final String acceptedTypes = allFunctions.values().stream()
private String getAcceptedTypesAsString() {
return allFunctions.values().stream()
.map(UdfIndex::formatAvailableSignatures)
.collect(Collectors.joining(System.lineSeparator()));
}

private KsqlException createVagueImplicitCastException(final List<SqlArgument> paramTypes) {
LOG.debug("Current UdfIndex:\n{}", describe());
throw new KsqlException("Function '" + udfName
+ "' cannot be resolved due to ambiguous method parameters "
+ getParamsAsString(paramTypes) + "."
+ System.lineSeparator()
+ "Use CAST() to explicitly cast your parameters to one of the supported function calls."
+ System.lineSeparator()
+ "Valid function calls are:"
+ System.lineSeparator()
+ getAcceptedTypesAsString()
+ System.lineSeparator()
+ "For detailed information on a function run: DESCRIBE FUNCTION <Function-Name>;");
}

private KsqlException createNoMatchingFunctionException(final List<SqlArgument> paramTypes) {
LOG.debug("Current UdfIndex:\n{}", describe());

final String requiredTypes = getParamsAsString(paramTypes);

final String acceptedTypes = getAcceptedTypesAsString();

return new KsqlException("Function '" + udfName
+ "' does not accept parameters " + requiredTypes + "."
Expand Down Expand Up @@ -273,17 +310,15 @@ private final class Node {

private final Map<Parameter, Node> children;
private T value;
private int order = 0;

private Node() {
this.children = new HashMap<>();
this.value = null;
}

private void update(final T function, final int order) {
private void update(final T function) {
if (compareFunctions.compare(function, value) > 0) {
value = function;
this.order = order;
}
}

Expand All @@ -307,8 +342,15 @@ public String toString() {
}

int compare(final Node other) {
final int compare = compareFunctions.compare(value, other.value);
return compare == 0 ? -(order - other.order) : compare;
final int compareVal = compareFunctions.compare(value, other.value);
return compareVal == 0 ? countGenerics(other) - countGenerics(this) : compareVal;
}

private int countGenerics(final Node node) {
return node.value.parameters().stream()
.filter(GenericsUtil::hasGenerics)
.mapToInt(p -> 1)
.sum();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static io.confluent.ksql.function.KsqlScalarFunction.INTERNAL_PATH;
import static io.confluent.ksql.function.types.ArrayType.of;
import static io.confluent.ksql.schema.ksql.types.SqlTypes.INTEGER;
import static io.confluent.ksql.schema.ksql.types.SqlTypes.BIGINT;
import static java.lang.System.lineSeparator;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -562,35 +563,6 @@ public void shouldFindNonVarargWithPartialNullValues() {
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldChooseFirstAddedWithNullValues() {
// Given:
givenFunctions(
function(EXPECTED, false, STRING),
function(OTHER, false, INT)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(Collections.singletonList(null));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldFindVarargWithNullValues() {
// Given:
givenFunctions(
function(EXPECTED, true, STRING_VARARGS)
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(Arrays.asList(new SqlArgument[]{null}));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldFindVarargWithSomeNullValues() {
// Given:
Expand Down Expand Up @@ -964,6 +936,61 @@ public void shouldThrowIfNoExactMatchAndImplicitCastDisabled() {
+ "(INTEGER)"));
}

@Test
public void shouldThrowOnAmbiguousImplicitCastWithoutGenerics() {
// Given:
givenFunctions(
function(FIRST_FUNC, false, LONG, LONG),
function(SECOND_FUNC, false, DOUBLE, DOUBLE)
);

// When:
final KsqlException e = assertThrows(KsqlException.class,
() -> udfIndex
.getFunction(ImmutableList.of(SqlArgument.of(INTEGER), SqlArgument.of(BIGINT))));

// Then:
assertThat(e.getMessage(), containsString("Function 'name' cannot be resolved due " +
"to ambiguous method parameters "
+ "(INTEGER, BIGINT)"));
}

@Test
public void shouldFindFewerGenerics() {
// Given:
givenFunctions(
function(EXPECTED, false, INT, GenericType.of("A"), INT),
function(OTHER, false, INT, GenericType.of("A"), GenericType.of("B"))
);

// When:
final KsqlScalarFunction fun = udfIndex.getFunction(ImmutableList
.of(SqlArgument.of(INTEGER), SqlArgument.of(INTEGER), SqlArgument.of(INTEGER)));

// Then:
assertThat(fun.name(), equalTo(EXPECTED));
}

@Test
public void shouldThrowOnAmbiguousImplicitCastWithGenerics() {
// Given:
givenFunctions(
function(FIRST_FUNC, false, LONG, GenericType.of("A"), GenericType.of("B")),
function(SECOND_FUNC, false, DOUBLE, GenericType.of("A"), GenericType.of("B"))
);

// When:
final KsqlException e = assertThrows(KsqlException.class,
() -> udfIndex
.getFunction(ImmutableList
.of(SqlArgument.of(INTEGER), SqlArgument.of(INTEGER), SqlArgument.of(INTEGER))));

// Then:
assertThat(e.getMessage(), containsString("Function 'name' cannot be resolved due " +
"to ambiguous method parameters "
+ "(INTEGER, INTEGER, INTEGER)"));
}

private void givenFunctions(final KsqlScalarFunction... functions) {
Arrays.stream(functions).forEach(udfIndex::addFunction);
}
Expand Down

0 comments on commit cd1a988

Please sign in to comment.