diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/DynamicUdfInvoker.java b/ksql-engine/src/main/java/io/confluent/ksql/function/DynamicUdfInvoker.java index c5504fb43fc0..3edd4636f6b0 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/DynamicUdfInvoker.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/function/DynamicUdfInvoker.java @@ -15,10 +15,10 @@ package io.confluent.ksql.function; +import java.lang.reflect.Array; import java.lang.reflect.GenericArrayType; import java.lang.reflect.Method; import java.lang.reflect.TypeVariable; -import java.util.Arrays; /** * An implementation of UdfInvoker which invokes the UDF dynamically using reflection @@ -50,17 +50,15 @@ public class DynamicUdfInvoker implements UdfInvoker { public Object eval(final Object udf, final Object... args) { try { final Object[] extractedArgs = extractArgs(args); - for (int i = 0; i < extractedArgs.length; i++) { - extractedArgs[i] = - UdfArgCoercer.coerceUdfArgs(extractedArgs[i], method.getParameterTypes()[i], i); - } return method.invoke(udf, extractedArgs); } catch (Exception e) { throw new KsqlFunctionException("Failed to invoke udf " + method, e); } } - // Method.invoke() is a pain and expects any varargs to be packaged up in a further Object[] + /* + Method.invoke() is a pain and expects any varargs to be packaged up in a further Object[] + */ private Object[] extractArgs(final Object... source) { if (!method.isVarArgs()) { return source; @@ -70,8 +68,14 @@ private Object[] extractArgs(final Object... source) { System.arraycopy(source, 0, args, 0, method.getParameterCount() - 1); final int start = method.getParameterCount() - 1; - final Object[] varargs = Arrays.copyOfRange(source, start, source.length); - args[start] = varargs; + final Class componentType = method.getParameterTypes()[start].getComponentType(); + + // Need to convert to array of component type - Method.invoke requires this + final Object val = Array.newInstance(componentType, source.length - start); + for (int i = start; i < source.length; i++) { + Array.set(val, i - start, source[i]); + } + args[start] = val; return args; } diff --git a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfArgCoercer.java b/ksql-engine/src/main/java/io/confluent/ksql/function/UdfArgCoercer.java deleted file mode 100644 index c0e0c68e283a..000000000000 --- a/ksql-engine/src/main/java/io/confluent/ksql/function/UdfArgCoercer.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.function; - -import com.google.common.primitives.Primitives; -import java.lang.reflect.Array; - -final class UdfArgCoercer { - - private UdfArgCoercer() { - } - - static T coerceUdfArgs( - final Object[] args, - final Class clazz, - final int index) { - final Object arg = args[index]; - return coerceUdfArgs(arg, clazz, index); - } - - static T coerceUdfArgs( - final Object arg, - final Class clazz, - final int index - ) { - if (arg == null) { - if (clazz.isPrimitive()) { - throw new KsqlFunctionException( - String.format( - "Can't coerce argument at index %d from null to a primitive type", index)); - } - return null; - } - - if (clazz.isArray()) { - try { - return fromArray(arg, clazz); - } catch (Exception e) { - throw new KsqlFunctionException( - String.format("Couldn't coerce array argument \"args[%d]\" to type %s", index, clazz) - ); - } - } - - // using boxed type is safe: long.class and Long.class are both of type Class - // and this is a no-op for non-primitives - final Class boxedType = Primitives.wrap(clazz); - if (boxedType.isAssignableFrom(arg.getClass())) { - return boxedType.cast(arg); - } else if (arg instanceof String) { - try { - return fromString((String) arg, clazz); - } catch (Exception e) { - throw new KsqlFunctionException( - String.format("Couldn't coerce string argument '\"args[%d]\"' to type %s", - index, clazz)); - } - } else if (arg instanceof Number) { - try { - return fromNumber((Number) arg, boxedType); - } catch (Exception e) { - throw new KsqlFunctionException( - String.format("Couldn't coerce numeric argument '\"args[%d]:(%s) %s\"' to type %s", - index, arg.getClass(), arg, clazz)); - } - } else { - throw new KsqlFunctionException( - String.format("Impossible to coerce (%s) %s into %s", arg.getClass(), arg, clazz)); - } - } - - @SuppressWarnings("unchecked") - private static T fromArray( - final Object args, - final Class arrayType - ) { - if (!args.getClass().isArray()) { - throw new KsqlFunctionException( - String.format("Cannot coerce non-array object %s to %s", args, arrayType)); - } - - final int length = Array.getLength(args); - final Class componentType = arrayType.getComponentType(); - final Object val = Array.newInstance(componentType, length); - for (int i = 0; i < length; i++) { - Array.set(val, i, coerceUdfArgs(Array.get(args, i), componentType, i)); - } - return (T) val; - } - - private static T fromNumber(final Number arg, final Class boxedType) { - if (Integer.class.isAssignableFrom(boxedType)) { - return boxedType.cast(arg.intValue()); - } else if (Long.class.isAssignableFrom(boxedType)) { - return boxedType.cast(arg.longValue()); - } else if (Double.class.isAssignableFrom(boxedType)) { - return boxedType.cast(arg.doubleValue()); - } else if (Float.class.isAssignableFrom(boxedType)) { - return boxedType.cast(arg.floatValue()); - } else if (Byte.class.isAssignableFrom(boxedType)) { - return boxedType.cast(arg.byteValue()); - } else if (Short.class.isAssignableFrom(boxedType)) { - return boxedType.cast(arg.shortValue()); - } - - throw new KsqlFunctionException(String.format("Cannot coerce %s into %s", arg, boxedType)); - } - - @SuppressWarnings("unchecked") - private static T fromString(final String arg, final Class clazz) { - if (Integer.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Integer.valueOf(arg); - } else if (Long.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Long.valueOf(arg); - } else if (Double.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Double.valueOf(arg); - } else if (Float.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Float.valueOf(arg); - } else if (Byte.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Byte.valueOf(arg); - } else if (Short.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Short.valueOf(arg); - } else if (Boolean.class.isAssignableFrom(Primitives.wrap(clazz))) { - return (T) Boolean.valueOf(arg); - } - - throw new KsqlFunctionException(String.format("Cannot coerce %s into %s", arg, clazz)); - } - -} \ No newline at end of file diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfArgCoercerTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfArgCoercerTest.java deleted file mode 100644 index 3924099c1e95..000000000000 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfArgCoercerTest.java +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Copyright 2019 Confluent Inc. - * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.function; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; - -import com.google.common.primitives.Primitives; -import java.lang.reflect.Array; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; - -public class UdfArgCoercerTest { - - @Rule - public final ExpectedException expectedException = ExpectedException.none(); - - @Test - public void testCoerceNumbers() { - // Given: - Object[] args = new Object[]{1, 1L, 1d, 1f}; - - // Then: - for (int i = 0; i < args.length; i++) { - assertThat(UdfArgCoercer.coerceUdfArgs(args, int.class, i), equalTo(1)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Integer.class, i), equalTo(1)); - - assertThat(UdfArgCoercer.coerceUdfArgs(args, long.class, i), equalTo(1L)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Long.class, i), equalTo(1L)); - - assertThat(UdfArgCoercer.coerceUdfArgs(args, double.class, i), equalTo(1.0)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Double.class, i), equalTo(1.0)); - } - } - - @Test - public void testCoerceStrings() { - // Given: - Object[] args = new Object[]{"1", "1.2", "true"}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, int.class, 0), equalTo(1)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Integer.class, 0), equalTo(1)); - - assertThat(UdfArgCoercer.coerceUdfArgs(args, long.class, 0), equalTo(1L)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Long.class, 0), equalTo(1L)); - - assertThat(UdfArgCoercer.coerceUdfArgs(args, double.class, 1), equalTo(1.2)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Double.class, 1), equalTo(1.2)); - - assertThat(UdfArgCoercer.coerceUdfArgs(args, boolean.class, 2), is(true)); - assertThat(UdfArgCoercer.coerceUdfArgs(args, boolean.class, 2), is(true)); - } - - @Test - public void testCoerceBoxedNull() { - // Given: - Object[] args = new Object[]{null}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, Integer.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Long.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Double.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, String.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Boolean.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Map.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, List.class, 0), nullValue()); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Object[].class, 0), nullValue()); - } - - @Test - public void testCoercePrimitiveFailsNull() { - // Given: - Object[] args = new Object[]{null}; - - // Then: - expectedException.expect(KsqlFunctionException.class); - expectedException.expectMessage("from null to a primitive type"); - - // When: - UdfArgCoercer.coerceUdfArgs(args, int.class, 0); - } - - @Test - public void testCoerceObjects() { - // Given: - Object[] args = new Object[]{new ArrayList<>(), new HashMap<>(), ""}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, List.class, 0), equalTo(new ArrayList<>())); - assertThat(UdfArgCoercer.coerceUdfArgs(args, Map.class, 1), equalTo(new HashMap<>())); - assertThat(UdfArgCoercer.coerceUdfArgs(args, String.class, 2), equalTo("")); - } - - @Test - public void shouldCoerceNullArray() { - // Given: - Object[] args = new Object[]{null}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, Object[].class, 0), equalTo(null)); - } - - @Test - public void shouldCoercePrimitiveArrays() { - // Given: - Object[] args = new Object[]{ - new int[]{1}, - new byte[]{1}, - new short[]{1}, - new float[]{1f}, - new double[]{1d}, - new boolean[]{true} - }; - - // Then: - for (int i = 0; i < args.length; i++) { - assertThat(UdfArgCoercer.coerceUdfArgs(args, args[i].getClass(), i), equalTo(args[i])); - } - } - - @Test - public void shouldCoerceBoxedArrays() { - // Given: - Object[] args = new Object[]{ - new Integer[]{1}, - new Byte[]{1}, - new Short[]{1}, - new Float[]{1f}, - new Double[]{1d}, - new Boolean[]{true} - }; - - // Then: - for (int i = 0; i < args.length; i++) { - assertThat(UdfArgCoercer.coerceUdfArgs(args, args[i].getClass(), i), equalTo(args[i])); - } - } - - @Test - public void shouldCoercePrimitiveArrayToBoxed() { - // Given: - Object[] args = new Object[]{ - new int[]{1}, - new byte[]{1}, - new short[]{1}, - new float[]{1f}, - new double[]{1d}, - new boolean[]{true} - }; - - // Then: - for (int i = 0; i < args.length; i++) { - final Class boxed = Primitives.wrap(args[i].getClass().getComponentType()); - final Class boxedArray = Array.newInstance(boxed, 0).getClass(); - assertThat(UdfArgCoercer.coerceUdfArgs(args, boxedArray, i), equalTo(args[i])); - } - } - - @Test - public void shouldCoerceNumberConversionArray() { - // Given: - Object[] args = new Object[]{new int[]{1}}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, double[].class, 0), equalTo(new double[]{1})); - } - - @Test - public void shouldCoerceArrayOfLists() { - // Given: - Object[] args = new Object[]{new List[]{new ArrayList()}}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, List[].class, 0), equalTo(new List[]{new ArrayList<>()})); - } - - @Test - public void shouldCoerceArrayOfMaps() { - // Given: - Object[] args = new Object[]{new Map[]{new HashMap<>()}}; - - // Then: - assertThat(UdfArgCoercer.coerceUdfArgs(args, Map[].class, 0), equalTo(new Map[]{new HashMap<>()})); - } - - @Test - public void shouldNotCoerceNonArrayToArray() { - // Given: - Object[] args = new Object[]{"not an array"}; - - // Expect: - expectedException.expect(KsqlFunctionException.class); - expectedException.expectMessage("Couldn't coerce array"); - - // When: - UdfArgCoercer.coerceUdfArgs(args, Object[].class, 0); - } - - @Test - public void testInvalidStringCoercion() { - // Given: - Object[] args = new Object[]{"not a number"}; - - // Then: - expectedException.expect(KsqlFunctionException.class); - expectedException.expectMessage("Couldn't coerce string"); - - // When: - UdfArgCoercer.coerceUdfArgs(args, int.class, 0); - } - - @Test - public void testInvalidNumberCoercion() { - // Given: - Object[] args = new Object[]{1}; - - // Then: - expectedException.expect(KsqlFunctionException.class); - expectedException.expectMessage("Couldn't coerce numeric"); - - // When: - UdfArgCoercer.coerceUdfArgs(args, Map.class, 0); - } - - @Test - public void testImpossibleCoercion() { - // Given - Object[] args = new Object[]{(Supplier) () -> null}; - - // Then: - expectedException.expect(KsqlFunctionException.class); - expectedException.expectMessage("Impossible to coerce"); - - // When: - UdfArgCoercer.coerceUdfArgs(args, int.class, 0); - } -} \ No newline at end of file diff --git a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index dfa112894208..22fe7764051c 100644 --- a/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksql-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -573,90 +573,90 @@ public void shouldEnsureFunctionReturnTypeIsDeepOptional() { } @Test - public void shouldCompileFunctionWithMapArgument() throws Exception { + public void shouldInvokeFunctionWithMapArgument() throws Exception { final UdfInvoker udf = UdfLoader.createUdfInvoker(getClass().getMethod("udf", Map.class)); assertThat(udf.eval(this, Collections.emptyMap()), equalTo("{}")); } @Test - public void shouldCompileFunctionWithListArgument() throws Exception { + public void shouldInvokeFunctionWithListArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", List.class)); assertThat(udf.eval(this, Collections.emptyList()), equalTo("[]")); } @Test - public void shouldCompileFunctionWithDoubleArgument() throws Exception { + public void shouldInvokeFunctionWithDoubleArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", Double.class)); assertThat(udf.eval(this, 1.0d), equalTo(1.0)); } @Test - public void shouldCompileFunctionWithIntegerArgument() throws Exception { + public void shouldInvokeFunctionWithIntegerArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", Integer.class)); assertThat(udf.eval(this, 1), equalTo(1)); } @Test - public void shouldCompileFunctionWithLongArgument() throws Exception { + public void shouldInvokeFunctionWithLongArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", Long.class)); assertThat(udf.eval(this, 1L), equalTo(1L)); } @Test - public void shouldCompileFunctionWithBooleanArgument() throws Exception { + public void shouldInvokeFunctionWithBooleanArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", Boolean.class)); assertThat(udf.eval(this, true), equalTo(true)); } @Test - public void shouldCompileFunctionWithIntArgument() throws Exception { + public void shouldInvokeFunctionWithIntArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfPrimitive", int.class)); assertThat(udf.eval(this, 1), equalTo(1)); } @Test - public void shouldCompileFunctionWithIntVarArgs() throws Exception { + public void shouldInvokeFunctionWithIntVarArgs() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfPrimitive", int[].class)); assertThat(udf.eval(this, 1, 1), equalTo(2)); } @Test - public void shouldCompileFunctionWithPrimitiveLongArgument() throws Exception { + public void shouldInvokeFunctionWithPrimitiveLongArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfPrimitive", long.class)); assertThat(udf.eval(this, 1), equalTo(1L)); } @Test - public void shouldCompileFunctionWithPrimitiveDoubleArgument() throws Exception { + public void shouldInvokeFunctionWithPrimitiveDoubleArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfPrimitive", double.class)); assertThat(udf.eval(this, 1), equalTo(1.0)); } @Test - public void shouldCompileFunctionWithPrimitiveBooleanArgument() throws Exception { + public void shouldInvokeFunctionWithPrimitiveBooleanArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfPrimitive", boolean.class)); assertThat(udf.eval(this, true), equalTo(true)); } @Test - public void shouldCompileFunctionWithStringArgument() throws Exception { + public void shouldInvokeFunctionWithStringArgument() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", String.class)); assertThat(udf.eval(this, "foo"), equalTo("foo")); } @Test - public void shouldCompileFunctionWithStringVarArgs() throws Exception { + public void shouldInvokeFunctionWithStringVarArgs() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udf", String[].class)); assertThat(udf.eval(this, "foo", "bar"), equalTo("foobar")); @@ -687,7 +687,7 @@ public void shouldHandleMethodsWithParameterizedGenericArguments() throws Except } @Test - public void shouldCompileUdafWithMethodWithNoArgs() throws Exception { + public void shouldInvokeUdafWithMethodWithNoArgs() throws Exception { final UdafFactoryInvoker creator = createUdfLoader().createUdafFactoryInvoker(TestUdaf.class.getMethod("createSumLong"), FunctionName.of("test-udf"), @@ -700,14 +700,14 @@ public void shouldCompileUdafWithMethodWithNoArgs() throws Exception { } @Test - public void shouldCompileFunctionWithStructReturnValue() throws Exception { + public void shouldInvokeFunctionWithStructReturnValue() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfStruct", String.class)); assertThat(udf.eval(this, "val"), equalTo(new Struct(STRUCT_SCHEMA).put("a", "val"))); } @Test - public void shouldCompileFunctionWithStructParameter() throws Exception { + public void shouldInvokeFunctionWithStructParameter() throws Exception { final UdfInvoker udf = UdfLoader .createUdfInvoker(getClass().getMethod("udfStruct", Struct.class)); assertThat(udf.eval(this, new Struct(STRUCT_SCHEMA).put("a", "val")), equalTo("val")); @@ -728,7 +728,7 @@ public void shouldImplementTableAggregateFunctionWhenTableUdafClass() throws Exc } @Test - public void shouldCompileUdafWhenMethodHasArgs() throws Exception { + public void shouldInvokeUdafWhenMethodHasArgs() throws Exception { final UdafFactoryInvoker creator = createUdfLoader().createUdafFactoryInvoker(TestUdaf.class.getMethod("createSumLengthString", String.class),