From dd3f365eba725c77cca63f9196dac174394ecb58 Mon Sep 17 00:00:00 2001 From: Steven Zhang <35498506+stevenpyzhang@users.noreply.github.com> Date: Fri, 29 Jan 2021 10:56:45 -0800 Subject: [PATCH] feat: add lambda syntax to grammar (#6868) * feat: add lambda syntax to grammar * change g4 * spacing --- .../io/confluent/ksql/util/KsqlConstants.java | 1 + .../ksql/engine/rewrite/AstSanitizer.java | 67 ++++++++++++-- .../rewrite/ExpressionTreeRewriter.java | 21 ++++- .../ksql/engine/rewrite/LambdaContext.java | 45 +++++++++ .../ksql/engine/rewrite/AstSanitizerTest.java | 70 ++++++++++++++ .../rewrite/ExpressionTreeRewriterTest.java | 9 ++ .../engine/rewrite/LambdaContextTest.java | 55 +++++++++++ .../ksql/execution/codegen/CodeGenRunner.java | 7 ++ .../execution/codegen/SqlToJavaVisitor.java | 17 ++++ .../execution/codegen/helpers/LambdaUtil.java | 58 +++++++++++- .../codegen/helpers/TriFunction.java | 31 +++++++ .../formatter/ExpressionFormatter.java | 19 ++++ .../expression/tree/ExpressionVisitor.java | 3 + .../expression/tree/LambdaFunctionCall.java | 91 +++++++++++++++++++ .../expression/tree/LambdaLiteral.java | 65 +++++++++++++ .../tree/TraversalExpressionVisitor.java | 11 +++ .../tree/VisitParentExpressionVisitor.java | 10 ++ .../execution/util/ExpressionTypeManager.java | 23 +++++ .../codegen/helpers/LambdaUtilTest.java | 59 +++++++++++- .../formatter/ExpressionFormatterTest.java | 20 ++++ .../tree/LambdaFunctionCallTest.java | 72 +++++++++++++++ .../io/confluent/ksql/parser/SqlBase.g4 | 9 +- .../io/confluent/ksql/parser/AstBuilder.java | 16 +++- .../confluent/ksql/parser/AstBuilderTest.java | 47 ++++++++++ 24 files changed, 811 insertions(+), 15 deletions(-) create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java create mode 100644 ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java create mode 100644 ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java create mode 100644 ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java create mode 100644 ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCallTest.java diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java b/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java index 78e6864b6e58..b42c4bd2abb5 100644 --- a/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java +++ b/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java @@ -40,6 +40,7 @@ private KsqlConstants() { public static final String DOT = "."; public static final String STRUCT_FIELD_REF = "->"; + public static final String LAMBDA_FUNCTION = "=>"; public static final String KSQL_SERVICE_ID_METRICS_TAG = "ksql_service_id"; diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java index 49590d4a56bc..c75db2dcad80 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java @@ -17,10 +17,13 @@ import static java.util.Objects.requireNonNull; +import com.google.common.collect.ImmutableSet; import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Expression; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; @@ -41,9 +44,12 @@ import io.confluent.ksql.util.AmbiguousColumnException; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.UnknownSourceException; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.BiFunction; +import java.util.stream.Collectors; /** * Validate and clean ASTs generated from externally supplied statements @@ -55,6 +61,7 @@ *
  • No unqualified column references are ambiguous
  • *
  • All single column select items have an alias set * that ensures they are unique across all sources
  • + *
  • Lambda arguments don't overlap with column references
  • * */ public final class AstSanitizer { @@ -71,15 +78,15 @@ public static Statement sanitize(final Statement node, final MetaStore metaStore final ExpressionRewriterPlugin expressionRewriterPlugin = new ExpressionRewriterPlugin(dataSourceExtractor); - final BiFunction expressionRewriter = - (e, v) -> ExpressionTreeRewriter.rewriteWith(expressionRewriterPlugin::process, e, v); + final BiFunction expressionRewriter = + (e,v) -> ExpressionTreeRewriter.rewriteWith(expressionRewriterPlugin::process, e, v); return (Statement) new StatementRewriter<>(expressionRewriter, rewriterPlugin::process) - .rewrite(node, null); + .rewrite(node, new SanitizerContext()); } private static final class RewriterPlugin extends - AstVisitor, StatementRewriter.Context> { + AstVisitor, StatementRewriter.Context> { private final MetaStore metaStore; private final DataSourceExtractor dataSourceExtractor; @@ -102,7 +109,7 @@ private static final class RewriterPlugin extends @Override protected Optional visitInsertInto( final InsertInto node, - final StatementRewriter.Context ctx + final StatementRewriter.Context ctx ) { final DataSource target = metaStore.getSource(node.getTarget()); if (target == null) { @@ -129,7 +136,7 @@ protected Optional visitInsertInto( @Override protected Optional visitSingleColumn( final SingleColumn singleColumn, - final StatementRewriter.Context ctx + final StatementRewriter.Context ctx ) { if (singleColumn.getAlias().isPresent()) { return Optional.empty(); @@ -157,7 +164,8 @@ protected Optional visitSingleColumn( } private static final class ExpressionRewriterPlugin extends - VisitParentExpressionVisitor, Context> { + VisitParentExpressionVisitor, + ExpressionTreeRewriter.Context> { private final DataSourceExtractor dataSourceExtractor; @@ -169,9 +177,13 @@ private static final class ExpressionRewriterPlugin extends @Override public Optional visitUnqualifiedColumnReference( final UnqualifiedColumnReferenceExp expression, - final Context ctx + final ExpressionTreeRewriter.Context ctx ) { final ColumnName columnName = expression.getColumnName(); + if (ctx.getContext().getLambdaArgs().size() > 0 + && ctx.getContext().getLambdaArgs().contains(columnName.text())) { + return Optional.of(new LambdaLiteral(columnName.text())); + } final List sourceNames = dataSourceExtractor.getSourcesFor(columnName); @@ -192,5 +204,44 @@ public Optional visitUnqualifiedColumnReference( ) ); } + + @Override + public Optional visitLambdaExpression( + final LambdaFunctionCall expression, + final ExpressionTreeRewriter.Context ctx + ) { + dataSourceExtractor.getAllSources().forEach(aliasedDataSource -> { + for (String argument : expression.getArguments()) { + if (aliasedDataSource.getDataSource().getSchema().columns().stream() + .map(column -> column.name().text()).collect(Collectors.toList()) + .contains(argument)) { + throw new KsqlException( + String.format( + "Lambda function argument can't be a column name: %s", argument)); + } + } + }); + + ctx.getContext().addLambdaArg(expression.getArguments()); + return visitExpression(expression, ctx); + } + } + + private static class SanitizerContext { + final Set lambdaArgs = new HashSet<>(); + + private void addLambdaArg(final List newArguments) { + final int previousLambdaArgumentsLength = lambdaArgs.size(); + lambdaArgs.addAll(newArguments); + if (new HashSet<>(lambdaArgs).size() + < previousLambdaArgumentsLength + 1) { + throw new KsqlException( + "Reusing lambda arguments in nested lambda is not allowed"); + } + } + + private Set getLambdaArgs() { + return ImmutableSet.copyOf(lambdaArgs); + } } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java index f651c5f4cad2..f4907c55a686 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java @@ -39,6 +39,8 @@ import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; import io.confluent.ksql.execution.expression.tree.LongLiteral; @@ -72,7 +74,7 @@ * @param A context type to be passed through to the plugin. */ public final class ExpressionTreeRewriter { - + public static final class Context { private final C context; private final ExpressionVisitor rewriter; @@ -465,6 +467,18 @@ public Expression visitCast(final Cast node, final C context) { return new Cast(node.getLocation(), expression, type); } + @Override + public Expression visitLambdaExpression(final LambdaFunctionCall node, final C context) { + final Optional result + = plugin.apply(node, new Context<>(context, this)); + if (result.isPresent()) { + return result.get(); + } + + final Expression expression = rewriter.apply(node.getBody(), context); + return new LambdaFunctionCall(node.getLocation(), node.getArguments(), expression); + } + @Override public Expression visitBooleanLiteral(final BooleanLiteral node, final C context) { return plugin.apply(node, new Context<>(context, this)).orElse(node); @@ -485,6 +499,11 @@ public Expression visitLongLiteral(final LongLiteral node, final C context) { return plugin.apply(node, new Context<>(context, this)).orElse(node); } + @Override + public Expression visitLambdaLiteral(final LambdaLiteral node, final C context) { + return plugin.apply(node, new Context<>(context, this)).orElse(node); + } + @Override public Expression visitNullLiteral(final NullLiteral node, final C context) { return plugin.apply(node, new Context<>(context, this)).orElse(node); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java new file mode 100644 index 000000000000..8d7e2c5c1a89 --- /dev/null +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.engine.rewrite; + +import io.confluent.ksql.util.KsqlException; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; + +public class LambdaContext { + private final List lambdaArguments; + + public LambdaContext(final List lambdaArguments) { + this.lambdaArguments = new ArrayList<>( + Objects.requireNonNull(lambdaArguments, "lambdaArguments")); + } + + public void addLambdaArguments(final List newArguments) { + final int previousLambdaArgumentsLength = lambdaArguments.size(); + lambdaArguments.addAll(newArguments); + if (new HashSet<>(lambdaArguments).size() + < previousLambdaArgumentsLength + newArguments.size()) { + throw new KsqlException("Duplicate lambda arguments are not allowed."); + } + } + + public List getLambdaArguments() { + return lambdaArguments; + } +} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java index beb66d929467..861063257122 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java @@ -24,10 +24,16 @@ import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; +import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.IntegerLiteral; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.AstBuilder; import io.confluent.ksql.parser.DefaultKsqlParser; @@ -36,6 +42,7 @@ import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SingleColumn; import io.confluent.ksql.parser.tree.Statement; +import io.confluent.ksql.schema.Operator; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.MetaStoreFixture; import java.util.List; @@ -196,6 +203,69 @@ public void shouldAddQualifierForJoinColumnReferenceFromRight() { )))); } + @Test + public void shouldSanitizeLambdaArguments() { + // Given: + final Statement stmt = givenQuery( + "SELECT TRANSFORM_ARRAY(Col4, X => X + 5) FROM TEST1;"); + + // When: + final Query result = (Query) AstSanitizer.sanitize(stmt, META_STORE); + + // Then: + assertThat(result.getSelect(), is(new Select(ImmutableList.of( + new SingleColumn( + new FunctionCall( + FunctionName.of("TRANSFORM_ARRAY"), + ImmutableList.of( + column(TEST1_NAME, "COL4"), + new LambdaFunctionCall( + ImmutableList.of("X"), + new ArithmeticBinaryExpression( + Operator.ADD, + new LambdaLiteral("X"), + new IntegerLiteral(5)) + ) + ) + ), + Optional.of(ColumnName.of("KSQL_COL_0"))) + )))); + + } + + @Test + public void shouldThrowOnColumnNamesUsedForLambdaArguments() { + // Given: + final Statement stmt = givenQuery( + "SELECT TRANSFORM_ARRAY(Col4, Col0 => Col0 + 5) FROM TEST1;"); + + final Exception e = assertThrows( + KsqlException.class, + () -> AstSanitizer.sanitize(stmt, META_STORE) + ); + + // Then: + assertThat(e.getMessage(), + containsString("Lambda function argument can't be a column name: COL0")); + + } + + @Test + public void shouldThrowOnDuplicateLambdaArguments() { + // Given: + final Statement stmt = givenQuery( + "SELECT TRANSFORM_ARRAY(Col4, X => TRANSFORM_ARRAY(Col4, X => X)) FROM TEST1;"); + + final Exception e = assertThrows( + KsqlException.class, + () -> AstSanitizer.sanitize(stmt, META_STORE) + ); + + // Then: + assertThat(e.getMessage(), + containsString("Reusing lambda arguments in nested lambda is not allowed")); + } + @Test public void shouldThrowOnAmbiguousQualifierForJoinColumnReference() { // Given: diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java index 5381a194e54c..d57e311f034d 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java @@ -182,6 +182,15 @@ public void shouldRewriteArithmeticBinaryUsingPlugin() { shouldRewriteUsingPlugin(parsed); } + @Test + public void shouldRewriteLambdaFunctionUsingPlugin() { + // Given: + final Expression parsed = parseExpression("TRANSFORM_ARRAY(Array[1,2], X => X + Col0)"); + + // When/Then: + shouldRewriteUsingPlugin(parsed); + } + @Test public void shouldRewriteBetweenPredicate() { // Given: diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java new file mode 100644 index 000000000000..f7bcf1adbddb --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package io.confluent.ksql.engine.rewrite; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.parser.tree.Statement; +import io.confluent.ksql.util.KsqlException; +import org.junit.Test; + +public class LambdaContextTest { + @Test + public void shouldThrowIfSourceDoesNotExist() { + // Given: + final LambdaContext context = new LambdaContext(ImmutableList.of("X")); + + // When: + context.addLambdaArguments(ImmutableList.of("Z", "Y")); + + // Then: + assertThat(context.getLambdaArguments(), is(ImmutableList.of("X", "Z", "Y"))); + } + + @Test + public void shouldThrowIfLambdaArgumentAlreadyUsed() { + // Given: + final LambdaContext context = new LambdaContext(ImmutableList.of("X")); + + // When: + final Exception e = assertThrows( + KsqlException.class, + () -> context.addLambdaArguments(ImmutableList.of("X", "Y")) + ); + + // Then: + assertThat(e.getMessage(), containsString( + "Duplicate lambda arguments are not allowed.")); + } +} \ No newline at end of file diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java index 9f06cf25e57d..d4bcc8b8b0d7 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java @@ -24,6 +24,7 @@ import io.confluent.ksql.execution.expression.tree.DereferenceExpression; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.SubscriptExpression; import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor; @@ -254,6 +255,12 @@ public Void visitDereferenceExpression(final DereferenceExpression node, final V return null; } + @Override + public Void visitLambdaExpression(final LambdaFunctionCall node, final Void context) { + process(node.getBody(), null); + return null; + } + private void addRequiredColumn(final ColumnName columnName) { final Column column = schema.findValueColumn(columnName) .orElseThrow(() -> new KsqlException( diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java index ce8e98cd4422..a2fda62354ae 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java @@ -54,6 +54,8 @@ import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; import io.confluent.ksql.execution.expression.tree.LongLiteral; @@ -104,6 +106,7 @@ import java.util.Map; import java.util.Objects; import java.util.StringJoiner; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -118,6 +121,7 @@ public class SqlToJavaVisitor { public static final List JAVA_IMPORTS = ImmutableList.of( "io.confluent.ksql.execution.codegen.helpers.ArrayAccess", "io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction", + "io.confluent.ksql.execution.codegen.helpers.TriFunction", "io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction.LazyWhenClause", "java.sql.Timestamp", "java.util.Arrays", @@ -130,6 +134,7 @@ public class SqlToJavaVisitor { "com.google.common.collect.ImmutableMap", "java.util.function.Supplier", Function.class.getCanonicalName(), + BiFunction.class.getCanonicalName(), DecimalUtil.class.getCanonicalName(), BigDecimal.class.getCanonicalName(), MathContext.class.getCanonicalName(), @@ -344,6 +349,18 @@ public Pair visitNullLiteral(final NullLiteral node, final Void return new Pair<>("null", null); } + @Override + public Pair visitLambdaExpression( + final LambdaFunctionCall lambdaFunctionCall, final Void context) { + return visitUnsupported(lambdaFunctionCall); + } + + @Override + public Pair visitLambdaLiteral( + final LambdaLiteral lambdaLiteral, final Void context) { + return visitUnsupported(lambdaLiteral); + } + @Override public Pair visitUnqualifiedColumnReference( final UnqualifiedColumnReferenceExp node, diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java index 95bbf7dfd3b9..9e86988fad01 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java @@ -15,6 +15,11 @@ package io.confluent.ksql.execution.codegen.helpers; +import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.Pair; + +import java.util.List; + /** * Functions to help with generating code to work around the fact that the script engine doesn't * support lambdas. @@ -29,7 +34,7 @@ private LambdaUtil() { * * @param argName the name of the single argument the {@code lambdaBody} expects. * @param argType the type of the single argument the {@code lambdaBody} expects. - * @param lambdaBody the body of the lambda. It will find the n + * @param lambdaBody the body of the lambda. * @return code to instantiate the function. */ public static String function( @@ -38,12 +43,59 @@ public static String function( final String lambdaBody ) { final String javaType = argType.getSimpleName(); - return "new Function() {\n" + final String function = "new Function() {\n" + " @Override\n" + " public Object apply(Object arg) {\n" + " " + javaType + " " + argName + " = (" + javaType + ") arg;\n" - + " return " + lambdaBody + ";\n" + + " return " + lambdaBody + ";\n" + + " }\n" + + "}"; + return function; + } + + /** + * Generate code to build a {@link java.util.function.Function}. + * + * @param argList a list of lambda arguments that the {@code lambdaBody} expects. + * The type is paired with each argument. + * @param lambdaBody the body of the lambda. + * @return code to instantiate the function. + */ + // CHECKSTYLE_RULES.OFF: FinalLocalVariable + public static String function( + final List>> argList, + final String lambdaBody + ) { + final StringBuilder arguments = new StringBuilder(); + int i = 0; + for (final Pair> argPair : argList) { + i++; + final String javaType = argPair.right.getSimpleName(); + arguments.append( + " " + javaType + " " + argPair.left + " = (" + javaType + ") arg" + i + ";\n"); + } + String functionType; + String functionApply; + if (argList.size() == 1) { + functionType = "Function()"; + functionApply = " public Object apply(Object arg) {\n"; + } else if (argList.size() == 2) { + functionType = "BiFunction()"; + functionApply = " public Object apply(Object arg1, Object arg2) {\n"; + } else if (argList.size() == 3) { + functionType = "TriFunction()"; + functionApply = " public Object apply(Object arg1, Object arg2, Object arg3) {\n"; + } else { + throw new KsqlException("Unsupported number of lambda arguments."); + } + + final String function = "new " + functionType + " {\n" + + " @Override\n" + + functionApply + + arguments.toString() + + " return " + lambdaBody + ";\n" + " }\n" + "}"; + return function; } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java new file mode 100644 index 000000000000..ce4848a6eaad --- /dev/null +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java @@ -0,0 +1,31 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.codegen.helpers; + +import java.util.Objects; +import java.util.function.Function; + +@FunctionalInterface +public interface TriFunction { + + R apply(A a, B b, C c); + + default TriFunction andThen( + Function after) { + Objects.requireNonNull(after); + return (A a, B b, C c) -> after.apply(apply(a, b, c)); + } +} diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java index 9df086eb36bb..9046c554f0a2 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java @@ -40,6 +40,8 @@ import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; import io.confluent.ksql.execution.expression.tree.LongLiteral; @@ -104,6 +106,11 @@ public String visitStringLiteral(final StringLiteral node, final Context context return formatStringLiteral(node.getValue()); } + @Override + public String visitLambdaLiteral(final LambdaLiteral node, final Context context) { + return String.valueOf(node.getValue()); + } + @Override public String visitSubscriptExpression(final SubscriptExpression node, final Context context) { return process(node.getBase(), context) @@ -387,6 +394,18 @@ public String visitInListExpression(final InListExpression node, final Context c return "(" + joinExpressions(node.getValues(), context) + ")"; } + @Override + public String visitLambdaExpression( + final LambdaFunctionCall node, final Context context) { + final StringBuilder builder = new StringBuilder(); + + builder.append('('); + Joiner.on(", ").appendTo(builder, node.getArguments()); + builder.append(") => "); + builder.append(process(node.getBody(), context)); + return builder.toString(); + } + private String formatBinaryExpression( final String operator, final Expression left, final Expression right, final Context context ) { diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java index 0783e4075154..1fe2da36a063 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java @@ -87,4 +87,7 @@ default R process(final Expression node, final C context) { R visitWhenClause(WhenClause exp, C context); + R visitLambdaExpression(LambdaFunctionCall exp, C context); + + R visitLambdaLiteral(LambdaLiteral exp, C context); } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java new file mode 100644 index 000000000000..7730a0bff15d --- /dev/null +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java @@ -0,0 +1,91 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"; you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.expression.tree; + +import static java.util.Objects.requireNonNull; + +import com.google.common.collect.ImmutableList; +import io.confluent.ksql.parser.NodeLocation; + +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + + +public class LambdaFunctionCall extends Expression { + + private final ImmutableList arguments; + private final Expression body; + + public LambdaFunctionCall( + final List name, + final Expression body + ) { + this(Optional.empty(), name, body); + } + + public LambdaFunctionCall( + final Optional location, + final List arguments, + final Expression body + ) { + super(location); + this.arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments")); + if (arguments.size() == 0) { + throw new IllegalArgumentException( + String.format("Lambda expression must have at least 1 argument. => %s", body.toString())); + } + final Set set = new HashSet<>(arguments); + if (set.size() < arguments.size()) { + throw new IllegalArgumentException( + String.format("Lambda arguments have duplicates: %s", arguments.toString())); + } + this.body = requireNonNull(body, "body is null"); + } + + public List getArguments() { + return arguments; + } + + public Expression getBody() { + return body; + } + + @Override + public R accept(final ExpressionVisitor visitor, final C context) { + return visitor.visitLambdaExpression(this, context); + } + + @Override + public boolean equals(final Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + final LambdaFunctionCall that = (LambdaFunctionCall) obj; + return Objects.equals(arguments, that.arguments) + && Objects.equals(body, that.body); + } + + @Override + public int hashCode() { + return Objects.hash(arguments, body); + } +} diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java new file mode 100644 index 000000000000..840cb6d77b7d --- /dev/null +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java @@ -0,0 +1,65 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.expression.tree; + +import com.google.errorprone.annotations.Immutable; +import io.confluent.ksql.parser.NodeLocation; + +import java.util.Objects; +import java.util.Optional; + +@Immutable +public class LambdaLiteral extends Literal { + + private final String lambdaCharacter; + + public LambdaLiteral(final String lambdaCharacter) { + this(Optional.empty(), lambdaCharacter); + } + + public LambdaLiteral(final Optional location, final String lambdaCharacter) { + super(location); + this.lambdaCharacter = lambdaCharacter; + } + + @Override + public String getValue() { + return lambdaCharacter; + } + + @Override + public R accept(final ExpressionVisitor visitor, final C context) { + return visitor.visitLambdaLiteral(this, context); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final LambdaLiteral that = (LambdaLiteral) o; + return lambdaCharacter.equals(that.lambdaCharacter); + } + + @Override + public int hashCode() { + return Objects.hashCode(lambdaCharacter); + } +} diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java index a09ab474d984..258e00809887 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java @@ -98,6 +98,12 @@ public Void visitFunctionCall(final FunctionCall node, final C context) { return null; } + @Override + public Void visitLambdaExpression(final LambdaFunctionCall node, final C context) { + process(node.getBody(), context); + return null; + } + @Override public Void visitDereferenceExpression(final DereferenceExpression node, final C context) { process(node.getBase(), context); @@ -197,6 +203,11 @@ public Void visitBooleanLiteral(final BooleanLiteral node, final C context) { return null; } + @Override + public Void visitLambdaLiteral(final LambdaLiteral node, final C context) { + return null; + } + @Override public Void visitUnqualifiedColumnReference( final UnqualifiedColumnReferenceExp node, diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java index e47291858d83..b20cc95310da 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java @@ -213,4 +213,14 @@ public R visitType(final Type node, final C context) { public R visitCast(final Cast node, final C context) { return visitExpression(node, context); } + + @Override + public R visitLambdaExpression(final LambdaFunctionCall node, final C context) { + return visitExpression(node, context); + } + + @Override + public R visitLambdaLiteral(final LambdaLiteral node, final C context) { + return visitLiteral(node, context); + } } diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java index 4ba02c7fef00..8256592e4b1d 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java @@ -37,6 +37,8 @@ import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; import io.confluent.ksql.execution.expression.tree.LongLiteral; @@ -136,6 +138,27 @@ public Void visitArithmeticUnary( return null; } + @Override + // CHECKSTYLE_RULES.OFF: TodoComment + public Void visitLambdaExpression( + final LambdaFunctionCall node, final ExpressionTypeContext context + ) { + process(node.getBody(), context); + // TODO: add proper type inference + context.setSqlType(SqlTypes.INTEGER); + return null; + } + + @Override + // CHECKSTYLE_RULES.OFF: TodoComment + public Void visitLambdaLiteral( + final LambdaLiteral node, final ExpressionTypeContext expressionTypeContext + ) { + // TODO: add proper type inference + expressionTypeContext.setSqlType(SqlTypes.INTEGER); + return null; + } + @Override public Void visitNotExpression( final NotExpression node, final ExpressionTypeContext expressionTypeContext diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java index a41d0b988d55..d2d504d85eec 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java @@ -19,8 +19,16 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import com.google.common.collect.ImmutableList; import io.confluent.ksql.execution.codegen.CodeGenTestUtil; + +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; import java.util.function.Function; + +import io.confluent.ksql.util.KsqlException; +import io.confluent.ksql.util.Pair; import org.junit.Test; @SuppressWarnings("unchecked") @@ -34,11 +42,60 @@ public void shouldGenerateFunctionCode() { // When: final String javaCode = LambdaUtil - .function(argName, argType, argName + ".longValue() + 1"); + .function(argName, argType, argName + " + 1"); // Then: final Object result = CodeGenTestUtil.cookAndEval(javaCode, Function.class); assertThat(result, is(instanceOf(Function.class))); assertThat(((Function) result).apply(10L), is(11L)); } + + @Test + public void shouldGenerateBiFunction() { + // Given: + final Pair> argName1 = new Pair<>("fred", Long.class); + final Pair> argName2 = new Pair<>("bob", Long.class); + + final List>> argList = ImmutableList.of(argName1, argName2); + + // When: + final String javaCode = LambdaUtil.function(argList, "fred + bob + 2"); + + // Then: + final Object result = CodeGenTestUtil.cookAndEval(javaCode, BiFunction.class); + assertThat(result, is(instanceOf(BiFunction.class))); + assertThat(((BiFunction) result).apply(10L, 15L), is(27L)); + } + + @Test + public void shouldGenerateTriFunction() { + // Given: + final Pair> argName1 = new Pair<>("fred", Long.class); + final Pair> argName2 = new Pair<>("bob", Long.class); + final Pair> argName3 = new Pair<>("tim", Long.class); + + final List>> argList = ImmutableList.of(argName1, argName2, argName3); + + // When: + final String javaCode = LambdaUtil.function(argList, "fred + bob + tim + 1"); + + // Then: + final Object result = CodeGenTestUtil.cookAndEval(javaCode, TriFunction.class); + assertThat(result, is(instanceOf(TriFunction.class))); + assertThat(((TriFunction) result).apply(10L, 15L, 3L), is(29L)); + } + + @Test(expected= KsqlException.class) + public void shouldThrowOnNonSupportedArguments() { + // Given: + final Pair> argName1 = new Pair<>("fred", Long.class); + final Pair> argName2 = new Pair<>("bob", Long.class); + final Pair> argName3 = new Pair<>("tim", Long.class); + final Pair> argName4 = new Pair<>("hello", Long.class); + + final List>> argList = ImmutableList.of(argName1, argName2, argName3, argName4); + + // When: + LambdaUtil.function(argList, "fred + bob + tim + hello + 1"); + } } \ No newline at end of file diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java index c279cbe01ce5..9eb7b13d5a68 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatterTest.java @@ -41,6 +41,8 @@ import io.confluent.ksql.execution.expression.tree.IntegerLiteral; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; +import io.confluent.ksql.execution.expression.tree.LambdaLiteral; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; import io.confluent.ksql.execution.expression.tree.LongLiteral; @@ -200,6 +202,24 @@ public void shouldFormatDereferenceExpression() { assertThat(text, equalTo("'foo'->name")); } + @Test + public void shouldFormatLambdaExpression() { + // Given: + final LambdaFunctionCall expression = new LambdaFunctionCall( + Optional.of(LOCATION), + ImmutableList.of("X", "Y"), + new LogicalBinaryExpression(LogicalBinaryExpression.Type.OR, + new LambdaLiteral("X"), + new LambdaLiteral("Y")) + ); + + // When: + final String text = ExpressionFormatter.formatExpression(expression); + + // Then: + assertThat(text, equalTo("(X, Y) => (X OR Y)")); + } + @Test public void shouldFormatFunctionCallWithCount() { final FunctionCall functionCall = new FunctionCall(FunctionName.of("COUNT"), diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCallTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCallTest.java new file mode 100644 index 000000000000..309785763458 --- /dev/null +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCallTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2021 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.execution.expression.tree; + +import com.google.common.collect.ImmutableList; +import com.google.common.testing.EqualsTester; +import io.confluent.ksql.parser.NodeLocation; +import java.util.List; +import java.util.Optional; +import org.junit.Test; + +public class LambdaFunctionCallTest { + + public static final NodeLocation SOME_LOCATION = new NodeLocation(0, 0); + public static final NodeLocation OTHER_LOCATION = new NodeLocation(1, 0); + private static final List SOME_ARGUMENTS = ImmutableList.of("X", "Y"); + private static final List OTHER_ARGUMENTS = ImmutableList.of("X"); + private static final Expression SOME_EXPRESSION = new StringLiteral("steven"); + private static final Expression OTHER_EXPRESSION = new StringLiteral("steve"); + + @Test + public void shouldImplementHashCodeAndEqualsProperty() { + new EqualsTester() + .addEqualityGroup( + // Note: At the moment location does not take part in equality testing + new LambdaFunctionCall(SOME_ARGUMENTS, SOME_EXPRESSION), + new LambdaFunctionCall(SOME_ARGUMENTS, SOME_EXPRESSION), + new LambdaFunctionCall(Optional.of(SOME_LOCATION), SOME_ARGUMENTS, SOME_EXPRESSION), + new LambdaFunctionCall(Optional.of(OTHER_LOCATION), SOME_ARGUMENTS, SOME_EXPRESSION) + ) + .addEqualityGroup( + new LambdaFunctionCall(SOME_ARGUMENTS, OTHER_EXPRESSION), + new LambdaFunctionCall(SOME_ARGUMENTS, OTHER_EXPRESSION), + new LambdaFunctionCall(Optional.of(SOME_LOCATION), SOME_ARGUMENTS, OTHER_EXPRESSION), + new LambdaFunctionCall(Optional.of(OTHER_LOCATION), SOME_ARGUMENTS, OTHER_EXPRESSION) + ).addEqualityGroup( + new LambdaFunctionCall(OTHER_ARGUMENTS, SOME_EXPRESSION), + new LambdaFunctionCall(OTHER_ARGUMENTS, SOME_EXPRESSION), + new LambdaFunctionCall(Optional.of(SOME_LOCATION), OTHER_ARGUMENTS, SOME_EXPRESSION), + new LambdaFunctionCall(Optional.of(OTHER_LOCATION), OTHER_ARGUMENTS, SOME_EXPRESSION) + ).addEqualityGroup( + new LambdaFunctionCall(OTHER_ARGUMENTS, OTHER_EXPRESSION), + new LambdaFunctionCall(OTHER_ARGUMENTS, OTHER_EXPRESSION), + new LambdaFunctionCall(Optional.of(SOME_LOCATION), OTHER_ARGUMENTS, OTHER_EXPRESSION), + new LambdaFunctionCall(Optional.of(OTHER_LOCATION), OTHER_ARGUMENTS, OTHER_EXPRESSION) + ) + .testEquals(); + } + + @Test(expected = IllegalArgumentException.class) + public void shouldThrowOnDuplicateArguments() { + new LambdaFunctionCall(ImmutableList.of("X", "X", "Y"), SOME_EXPRESSION); + } + + @Test(expected = IllegalArgumentException.class) + public void shouldThrowOnNoArguments() { + new LambdaFunctionCall(ImmutableList.of(), SOME_EXPRESSION); + } +} \ No newline at end of file diff --git a/ksqldb-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 b/ksqldb-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 index 98aae3b60be7..6a98e8b789e2 100644 --- a/ksqldb-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 +++ b/ksqldb-parser/src/main/antlr4/io/confluent/ksql/parser/SqlBase.g4 @@ -296,7 +296,7 @@ primaryExpression | MAP '(' (expression ASSIGN expression (',' expression ASSIGN expression)*)? ')' #mapConstructor | STRUCT '(' (identifier ASSIGN expression (',' identifier ASSIGN expression)*)? ')' #structConstructor | identifier '(' ASTERISK ')' #functionCall - | identifier'(' (expression (',' expression)*)? ')' #functionCall + | identifier '(' (expression (',' expression)* (',' lambdaFunction)*)? ')' #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | identifier '.' identifier #qualifiedColumnReference @@ -346,6 +346,11 @@ identifier | DIGIT_IDENTIFIER #digitIdentifier ; +lambdaFunction + : identifier '=>' expression #lambda + | '(' identifier (',' identifier)* ')' '=>' expression #lambda + ; + variableName : IDENTIFIER ; @@ -545,6 +550,8 @@ CONCAT: '||'; ASSIGN: ':='; STRUCT_FIELD_REF: '->'; +LAMBDA_EXPRESSION: '=>'; + STRING : '\'' ( ~'\'' | '\'\'' )* '\'' ; diff --git a/ksqldb-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java b/ksqldb-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java index a129a858207d..25a4ebffff5d 100644 --- a/ksqldb-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java +++ b/ksqldb-parser/src/main/java/io/confluent/ksql/parser/AstBuilder.java @@ -42,6 +42,7 @@ import io.confluent.ksql.execution.expression.tree.InPredicate; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; import io.confluent.ksql.execution.expression.tree.LikePredicate; import io.confluent.ksql.execution.expression.tree.Literal; import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression; @@ -637,6 +638,17 @@ public Node visitSelectSingle(final SqlBaseParser.SelectSingleContext context) { } } + @Override + public Node visitLambda(final SqlBaseParser.LambdaContext context) { + final List arguments = context.identifier().stream() + .map(ParserUtil::getIdentifierText) + .collect(toList()); + + final Expression body = (Expression) visit(context.expression()); + + return new LambdaFunctionCall(getLocation(context), arguments, body); + } + @Override public Node visitListTopics(final SqlBaseParser.ListTopicsContext context) { return new ListTopics(getLocation(context), @@ -1173,10 +1185,12 @@ public Node visitWhenClause(final SqlBaseParser.WhenClauseContext context) { @Override public Node visitFunctionCall(final SqlBaseParser.FunctionCallContext context) { + final List expressionList = visit(context.expression(), Expression.class); + expressionList.addAll(visit(context.lambdaFunction(), Expression.class)); return new FunctionCall( getLocation(context), FunctionName.of(ParserUtil.getIdentifierText(context.identifier())), - visit(context.expression(), Expression.class) + expressionList ); } diff --git a/ksqldb-parser/src/test/java/io/confluent/ksql/parser/AstBuilderTest.java b/ksqldb-parser/src/test/java/io/confluent/ksql/parser/AstBuilderTest.java index 6015de85b194..f690ba87eb69 100644 --- a/ksqldb-parser/src/test/java/io/confluent/ksql/parser/AstBuilderTest.java +++ b/ksqldb-parser/src/test/java/io/confluent/ksql/parser/AstBuilderTest.java @@ -26,14 +26,20 @@ import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; +import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression; +import io.confluent.ksql.execution.expression.tree.FunctionCall; +import io.confluent.ksql.execution.expression.tree.IntegerLiteral; +import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.KsqlParser.ParsedStatement; import io.confluent.ksql.parser.SqlBaseParser.SingleStatementContext; +import io.confluent.ksql.parser.exception.ParseFailedException; import io.confluent.ksql.parser.tree.AliasedRelation; import io.confluent.ksql.parser.tree.AllColumns; import io.confluent.ksql.parser.tree.Explain; @@ -43,6 +49,7 @@ import io.confluent.ksql.parser.tree.Select; import io.confluent.ksql.parser.tree.SingleColumn; import io.confluent.ksql.parser.tree.Table; +import io.confluent.ksql.schema.Operator; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.MetaStoreFixture; import java.util.List; @@ -278,6 +285,46 @@ public void shouldIncludeSelectAliasIfPresent() { )))); } + @Test + public void shouldBuildLambdaFunction() { + // Given: + final SingleStatementContext stmt = givenQuery("SELECT TRANSFORM_ARRAY(Col4, X => X + 5) FROM TEST1;"); + + // When: + final Query result = (Query) builder.buildStatement(stmt); + + // Then: + assertThat(result.getSelect(), is(new Select(ImmutableList.of( + new SingleColumn( + new FunctionCall( + FunctionName.of("TRANSFORM_ARRAY"), + ImmutableList.of( + column("COL4"), + new LambdaFunctionCall( + ImmutableList.of("X"), + new ArithmeticBinaryExpression( + Operator.ADD, + column("X"), + new IntegerLiteral(5)) + ) + ) + ), + Optional.empty()) + )))); + } + + @Test + public void shouldNotBuildLambdaFunctionNotLastArgument() { + // Given: + final Exception e = assertThrows( + ParseFailedException.class, + () -> givenQuery("SELECT TRANSFORM_ARRAY(X => X + 5, Col4) FROM TEST1;") + ); + + // Then: + assertThat(e.getMessage(), containsString("mismatched input '=>' expecting {',', ')'}")); + } + @Test public void shouldHandleUnqualifiedSelectStar() { // Given: