From dd3f365eba725c77cca63f9196dac174394ecb58 Mon Sep 17 00:00:00 2001
From: Steven Zhang <35498506+stevenpyzhang@users.noreply.github.com>
Date: Fri, 29 Jan 2021 10:56:45 -0800
Subject: [PATCH] feat: add lambda syntax to grammar (#6868)
* feat: add lambda syntax to grammar
* change g4
* spacing
---
.../io/confluent/ksql/util/KsqlConstants.java | 1 +
.../ksql/engine/rewrite/AstSanitizer.java | 67 ++++++++++++--
.../rewrite/ExpressionTreeRewriter.java | 21 ++++-
.../ksql/engine/rewrite/LambdaContext.java | 45 +++++++++
.../ksql/engine/rewrite/AstSanitizerTest.java | 70 ++++++++++++++
.../rewrite/ExpressionTreeRewriterTest.java | 9 ++
.../engine/rewrite/LambdaContextTest.java | 55 +++++++++++
.../ksql/execution/codegen/CodeGenRunner.java | 7 ++
.../execution/codegen/SqlToJavaVisitor.java | 17 ++++
.../execution/codegen/helpers/LambdaUtil.java | 58 +++++++++++-
.../codegen/helpers/TriFunction.java | 31 +++++++
.../formatter/ExpressionFormatter.java | 19 ++++
.../expression/tree/ExpressionVisitor.java | 3 +
.../expression/tree/LambdaFunctionCall.java | 91 +++++++++++++++++++
.../expression/tree/LambdaLiteral.java | 65 +++++++++++++
.../tree/TraversalExpressionVisitor.java | 11 +++
.../tree/VisitParentExpressionVisitor.java | 10 ++
.../execution/util/ExpressionTypeManager.java | 23 +++++
.../codegen/helpers/LambdaUtilTest.java | 59 +++++++++++-
.../formatter/ExpressionFormatterTest.java | 20 ++++
.../tree/LambdaFunctionCallTest.java | 72 +++++++++++++++
.../io/confluent/ksql/parser/SqlBase.g4 | 9 +-
.../io/confluent/ksql/parser/AstBuilder.java | 16 +++-
.../confluent/ksql/parser/AstBuilderTest.java | 47 ++++++++++
24 files changed, 811 insertions(+), 15 deletions(-)
create mode 100644 ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java
create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java
create mode 100644 ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java
create mode 100644 ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java
create mode 100644 ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java
create mode 100644 ksqldb-execution/src/test/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCallTest.java
diff --git a/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java b/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java
index 78e6864b6e58..b42c4bd2abb5 100644
--- a/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java
+++ b/ksqldb-common/src/main/java/io/confluent/ksql/util/KsqlConstants.java
@@ -40,6 +40,7 @@ private KsqlConstants() {
public static final String DOT = ".";
public static final String STRUCT_FIELD_REF = "->";
+ public static final String LAMBDA_FUNCTION = "=>";
public static final String KSQL_SERVICE_ID_METRICS_TAG = "ksql_service_id";
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java
index 49590d4a56bc..c75db2dcad80 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/AstSanitizer.java
@@ -17,10 +17,13 @@
import static java.util.Objects.requireNonNull;
+import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.analyzer.Analysis.AliasedDataSource;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
@@ -41,9 +44,12 @@
import io.confluent.ksql.util.AmbiguousColumnException;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.UnknownSourceException;
+import java.util.HashSet;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
import java.util.function.BiFunction;
+import java.util.stream.Collectors;
/**
* Validate and clean ASTs generated from externally supplied statements
@@ -55,6 +61,7 @@
*
No unqualified column references are ambiguous
* All single column select items have an alias set
* that ensures they are unique across all sources
+ * Lambda arguments don't overlap with column references
*
*/
public final class AstSanitizer {
@@ -71,15 +78,15 @@ public static Statement sanitize(final Statement node, final MetaStore metaStore
final ExpressionRewriterPlugin expressionRewriterPlugin =
new ExpressionRewriterPlugin(dataSourceExtractor);
- final BiFunction expressionRewriter =
- (e, v) -> ExpressionTreeRewriter.rewriteWith(expressionRewriterPlugin::process, e, v);
+ final BiFunction expressionRewriter =
+ (e,v) -> ExpressionTreeRewriter.rewriteWith(expressionRewriterPlugin::process, e, v);
return (Statement) new StatementRewriter<>(expressionRewriter, rewriterPlugin::process)
- .rewrite(node, null);
+ .rewrite(node, new SanitizerContext());
}
private static final class RewriterPlugin extends
- AstVisitor, StatementRewriter.Context> {
+ AstVisitor, StatementRewriter.Context> {
private final MetaStore metaStore;
private final DataSourceExtractor dataSourceExtractor;
@@ -102,7 +109,7 @@ private static final class RewriterPlugin extends
@Override
protected Optional visitInsertInto(
final InsertInto node,
- final StatementRewriter.Context ctx
+ final StatementRewriter.Context ctx
) {
final DataSource target = metaStore.getSource(node.getTarget());
if (target == null) {
@@ -129,7 +136,7 @@ protected Optional visitInsertInto(
@Override
protected Optional visitSingleColumn(
final SingleColumn singleColumn,
- final StatementRewriter.Context ctx
+ final StatementRewriter.Context ctx
) {
if (singleColumn.getAlias().isPresent()) {
return Optional.empty();
@@ -157,7 +164,8 @@ protected Optional visitSingleColumn(
}
private static final class ExpressionRewriterPlugin extends
- VisitParentExpressionVisitor, Context> {
+ VisitParentExpressionVisitor,
+ ExpressionTreeRewriter.Context> {
private final DataSourceExtractor dataSourceExtractor;
@@ -169,9 +177,13 @@ private static final class ExpressionRewriterPlugin extends
@Override
public Optional visitUnqualifiedColumnReference(
final UnqualifiedColumnReferenceExp expression,
- final Context ctx
+ final ExpressionTreeRewriter.Context ctx
) {
final ColumnName columnName = expression.getColumnName();
+ if (ctx.getContext().getLambdaArgs().size() > 0
+ && ctx.getContext().getLambdaArgs().contains(columnName.text())) {
+ return Optional.of(new LambdaLiteral(columnName.text()));
+ }
final List sourceNames = dataSourceExtractor.getSourcesFor(columnName);
@@ -192,5 +204,44 @@ public Optional visitUnqualifiedColumnReference(
)
);
}
+
+ @Override
+ public Optional visitLambdaExpression(
+ final LambdaFunctionCall expression,
+ final ExpressionTreeRewriter.Context ctx
+ ) {
+ dataSourceExtractor.getAllSources().forEach(aliasedDataSource -> {
+ for (String argument : expression.getArguments()) {
+ if (aliasedDataSource.getDataSource().getSchema().columns().stream()
+ .map(column -> column.name().text()).collect(Collectors.toList())
+ .contains(argument)) {
+ throw new KsqlException(
+ String.format(
+ "Lambda function argument can't be a column name: %s", argument));
+ }
+ }
+ });
+
+ ctx.getContext().addLambdaArg(expression.getArguments());
+ return visitExpression(expression, ctx);
+ }
+ }
+
+ private static class SanitizerContext {
+ final Set lambdaArgs = new HashSet<>();
+
+ private void addLambdaArg(final List newArguments) {
+ final int previousLambdaArgumentsLength = lambdaArgs.size();
+ lambdaArgs.addAll(newArguments);
+ if (new HashSet<>(lambdaArgs).size()
+ < previousLambdaArgumentsLength + 1) {
+ throw new KsqlException(
+ "Reusing lambda arguments in nested lambda is not allowed");
+ }
+ }
+
+ private Set getLambdaArgs() {
+ return ImmutableSet.copyOf(lambdaArgs);
+ }
}
}
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java
index f651c5f4cad2..f4907c55a686 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriter.java
@@ -39,6 +39,8 @@
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
@@ -72,7 +74,7 @@
* @param A context type to be passed through to the plugin.
*/
public final class ExpressionTreeRewriter {
-
+
public static final class Context {
private final C context;
private final ExpressionVisitor rewriter;
@@ -465,6 +467,18 @@ public Expression visitCast(final Cast node, final C context) {
return new Cast(node.getLocation(), expression, type);
}
+ @Override
+ public Expression visitLambdaExpression(final LambdaFunctionCall node, final C context) {
+ final Optional result
+ = plugin.apply(node, new Context<>(context, this));
+ if (result.isPresent()) {
+ return result.get();
+ }
+
+ final Expression expression = rewriter.apply(node.getBody(), context);
+ return new LambdaFunctionCall(node.getLocation(), node.getArguments(), expression);
+ }
+
@Override
public Expression visitBooleanLiteral(final BooleanLiteral node, final C context) {
return plugin.apply(node, new Context<>(context, this)).orElse(node);
@@ -485,6 +499,11 @@ public Expression visitLongLiteral(final LongLiteral node, final C context) {
return plugin.apply(node, new Context<>(context, this)).orElse(node);
}
+ @Override
+ public Expression visitLambdaLiteral(final LambdaLiteral node, final C context) {
+ return plugin.apply(node, new Context<>(context, this)).orElse(node);
+ }
+
@Override
public Expression visitNullLiteral(final NullLiteral node, final C context) {
return plugin.apply(node, new Context<>(context, this)).orElse(node);
diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java
new file mode 100644
index 000000000000..8d7e2c5c1a89
--- /dev/null
+++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/rewrite/LambdaContext.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright 2021 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.engine.rewrite;
+
+import io.confluent.ksql.util.KsqlException;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+
+public class LambdaContext {
+ private final List lambdaArguments;
+
+ public LambdaContext(final List lambdaArguments) {
+ this.lambdaArguments = new ArrayList<>(
+ Objects.requireNonNull(lambdaArguments, "lambdaArguments"));
+ }
+
+ public void addLambdaArguments(final List newArguments) {
+ final int previousLambdaArgumentsLength = lambdaArguments.size();
+ lambdaArguments.addAll(newArguments);
+ if (new HashSet<>(lambdaArguments).size()
+ < previousLambdaArgumentsLength + newArguments.size()) {
+ throw new KsqlException("Duplicate lambda arguments are not allowed.");
+ }
+ }
+
+ public List getLambdaArguments() {
+ return lambdaArguments;
+ }
+}
diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java
index beb66d929467..861063257122 100644
--- a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java
+++ b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/AstSanitizerTest.java
@@ -24,10 +24,16 @@
import static org.mockito.Mockito.mock;
import com.google.common.collect.ImmutableList;
+import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
+import io.confluent.ksql.execution.expression.tree.FunctionCall;
+import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.name.ColumnName;
+import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.parser.AstBuilder;
import io.confluent.ksql.parser.DefaultKsqlParser;
@@ -36,6 +42,7 @@
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
+import io.confluent.ksql.schema.Operator;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.MetaStoreFixture;
import java.util.List;
@@ -196,6 +203,69 @@ public void shouldAddQualifierForJoinColumnReferenceFromRight() {
))));
}
+ @Test
+ public void shouldSanitizeLambdaArguments() {
+ // Given:
+ final Statement stmt = givenQuery(
+ "SELECT TRANSFORM_ARRAY(Col4, X => X + 5) FROM TEST1;");
+
+ // When:
+ final Query result = (Query) AstSanitizer.sanitize(stmt, META_STORE);
+
+ // Then:
+ assertThat(result.getSelect(), is(new Select(ImmutableList.of(
+ new SingleColumn(
+ new FunctionCall(
+ FunctionName.of("TRANSFORM_ARRAY"),
+ ImmutableList.of(
+ column(TEST1_NAME, "COL4"),
+ new LambdaFunctionCall(
+ ImmutableList.of("X"),
+ new ArithmeticBinaryExpression(
+ Operator.ADD,
+ new LambdaLiteral("X"),
+ new IntegerLiteral(5))
+ )
+ )
+ ),
+ Optional.of(ColumnName.of("KSQL_COL_0")))
+ ))));
+
+ }
+
+ @Test
+ public void shouldThrowOnColumnNamesUsedForLambdaArguments() {
+ // Given:
+ final Statement stmt = givenQuery(
+ "SELECT TRANSFORM_ARRAY(Col4, Col0 => Col0 + 5) FROM TEST1;");
+
+ final Exception e = assertThrows(
+ KsqlException.class,
+ () -> AstSanitizer.sanitize(stmt, META_STORE)
+ );
+
+ // Then:
+ assertThat(e.getMessage(),
+ containsString("Lambda function argument can't be a column name: COL0"));
+
+ }
+
+ @Test
+ public void shouldThrowOnDuplicateLambdaArguments() {
+ // Given:
+ final Statement stmt = givenQuery(
+ "SELECT TRANSFORM_ARRAY(Col4, X => TRANSFORM_ARRAY(Col4, X => X)) FROM TEST1;");
+
+ final Exception e = assertThrows(
+ KsqlException.class,
+ () -> AstSanitizer.sanitize(stmt, META_STORE)
+ );
+
+ // Then:
+ assertThat(e.getMessage(),
+ containsString("Reusing lambda arguments in nested lambda is not allowed"));
+ }
+
@Test
public void shouldThrowOnAmbiguousQualifierForJoinColumnReference() {
// Given:
diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java
index 5381a194e54c..d57e311f034d 100644
--- a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java
+++ b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/ExpressionTreeRewriterTest.java
@@ -182,6 +182,15 @@ public void shouldRewriteArithmeticBinaryUsingPlugin() {
shouldRewriteUsingPlugin(parsed);
}
+ @Test
+ public void shouldRewriteLambdaFunctionUsingPlugin() {
+ // Given:
+ final Expression parsed = parseExpression("TRANSFORM_ARRAY(Array[1,2], X => X + Col0)");
+
+ // When/Then:
+ shouldRewriteUsingPlugin(parsed);
+ }
+
@Test
public void shouldRewriteBetweenPredicate() {
// Given:
diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java
new file mode 100644
index 000000000000..f7bcf1adbddb
--- /dev/null
+++ b/ksqldb-engine/src/test/java/io/confluent/ksql/engine/rewrite/LambdaContextTest.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright 2021 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package io.confluent.ksql.engine.rewrite;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThrows;
+
+import com.google.common.collect.ImmutableList;
+import io.confluent.ksql.parser.tree.Statement;
+import io.confluent.ksql.util.KsqlException;
+import org.junit.Test;
+
+public class LambdaContextTest {
+ @Test
+ public void shouldThrowIfSourceDoesNotExist() {
+ // Given:
+ final LambdaContext context = new LambdaContext(ImmutableList.of("X"));
+
+ // When:
+ context.addLambdaArguments(ImmutableList.of("Z", "Y"));
+
+ // Then:
+ assertThat(context.getLambdaArguments(), is(ImmutableList.of("X", "Z", "Y")));
+ }
+
+ @Test
+ public void shouldThrowIfLambdaArgumentAlreadyUsed() {
+ // Given:
+ final LambdaContext context = new LambdaContext(ImmutableList.of("X"));
+
+ // When:
+ final Exception e = assertThrows(
+ KsqlException.class,
+ () -> context.addLambdaArguments(ImmutableList.of("X", "Y"))
+ );
+
+ // Then:
+ assertThat(e.getMessage(), containsString(
+ "Duplicate lambda arguments are not allowed."));
+ }
+}
\ No newline at end of file
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java
index 9f06cf25e57d..d4bcc8b8b0d7 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/CodeGenRunner.java
@@ -24,6 +24,7 @@
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.SubscriptExpression;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
@@ -254,6 +255,12 @@ public Void visitDereferenceExpression(final DereferenceExpression node, final V
return null;
}
+ @Override
+ public Void visitLambdaExpression(final LambdaFunctionCall node, final Void context) {
+ process(node.getBody(), null);
+ return null;
+ }
+
private void addRequiredColumn(final ColumnName columnName) {
final Column column = schema.findValueColumn(columnName)
.orElseThrow(() -> new KsqlException(
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java
index ce8e98cd4422..a2fda62354ae 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/SqlToJavaVisitor.java
@@ -54,6 +54,8 @@
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
@@ -104,6 +106,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;
+import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@@ -118,6 +121,7 @@ public class SqlToJavaVisitor {
public static final List JAVA_IMPORTS = ImmutableList.of(
"io.confluent.ksql.execution.codegen.helpers.ArrayAccess",
"io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction",
+ "io.confluent.ksql.execution.codegen.helpers.TriFunction",
"io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction.LazyWhenClause",
"java.sql.Timestamp",
"java.util.Arrays",
@@ -130,6 +134,7 @@ public class SqlToJavaVisitor {
"com.google.common.collect.ImmutableMap",
"java.util.function.Supplier",
Function.class.getCanonicalName(),
+ BiFunction.class.getCanonicalName(),
DecimalUtil.class.getCanonicalName(),
BigDecimal.class.getCanonicalName(),
MathContext.class.getCanonicalName(),
@@ -344,6 +349,18 @@ public Pair visitNullLiteral(final NullLiteral node, final Void
return new Pair<>("null", null);
}
+ @Override
+ public Pair visitLambdaExpression(
+ final LambdaFunctionCall lambdaFunctionCall, final Void context) {
+ return visitUnsupported(lambdaFunctionCall);
+ }
+
+ @Override
+ public Pair visitLambdaLiteral(
+ final LambdaLiteral lambdaLiteral, final Void context) {
+ return visitUnsupported(lambdaLiteral);
+ }
+
@Override
public Pair visitUnqualifiedColumnReference(
final UnqualifiedColumnReferenceExp node,
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java
index 95bbf7dfd3b9..9e86988fad01 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtil.java
@@ -15,6 +15,11 @@
package io.confluent.ksql.execution.codegen.helpers;
+import io.confluent.ksql.util.KsqlException;
+import io.confluent.ksql.util.Pair;
+
+import java.util.List;
+
/**
* Functions to help with generating code to work around the fact that the script engine doesn't
* support lambdas.
@@ -29,7 +34,7 @@ private LambdaUtil() {
*
* @param argName the name of the single argument the {@code lambdaBody} expects.
* @param argType the type of the single argument the {@code lambdaBody} expects.
- * @param lambdaBody the body of the lambda. It will find the n
+ * @param lambdaBody the body of the lambda.
* @return code to instantiate the function.
*/
public static String function(
@@ -38,12 +43,59 @@ public static String function(
final String lambdaBody
) {
final String javaType = argType.getSimpleName();
- return "new Function() {\n"
+ final String function = "new Function() {\n"
+ " @Override\n"
+ " public Object apply(Object arg) {\n"
+ " " + javaType + " " + argName + " = (" + javaType + ") arg;\n"
- + " return " + lambdaBody + ";\n"
+ + " return " + lambdaBody + ";\n"
+ + " }\n"
+ + "}";
+ return function;
+ }
+
+ /**
+ * Generate code to build a {@link java.util.function.Function}.
+ *
+ * @param argList a list of lambda arguments that the {@code lambdaBody} expects.
+ * The type is paired with each argument.
+ * @param lambdaBody the body of the lambda.
+ * @return code to instantiate the function.
+ */
+ // CHECKSTYLE_RULES.OFF: FinalLocalVariable
+ public static String function(
+ final List>> argList,
+ final String lambdaBody
+ ) {
+ final StringBuilder arguments = new StringBuilder();
+ int i = 0;
+ for (final Pair> argPair : argList) {
+ i++;
+ final String javaType = argPair.right.getSimpleName();
+ arguments.append(
+ " " + javaType + " " + argPair.left + " = (" + javaType + ") arg" + i + ";\n");
+ }
+ String functionType;
+ String functionApply;
+ if (argList.size() == 1) {
+ functionType = "Function()";
+ functionApply = " public Object apply(Object arg) {\n";
+ } else if (argList.size() == 2) {
+ functionType = "BiFunction()";
+ functionApply = " public Object apply(Object arg1, Object arg2) {\n";
+ } else if (argList.size() == 3) {
+ functionType = "TriFunction()";
+ functionApply = " public Object apply(Object arg1, Object arg2, Object arg3) {\n";
+ } else {
+ throw new KsqlException("Unsupported number of lambda arguments.");
+ }
+
+ final String function = "new " + functionType + " {\n"
+ + " @Override\n"
+ + functionApply
+ + arguments.toString()
+ + " return " + lambdaBody + ";\n"
+ " }\n"
+ "}";
+ return function;
}
}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java
new file mode 100644
index 000000000000..ce4848a6eaad
--- /dev/null
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/codegen/helpers/TriFunction.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright 2021 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.execution.codegen.helpers;
+
+import java.util.Objects;
+import java.util.function.Function;
+
+@FunctionalInterface
+public interface TriFunction {
+
+ R apply(A a, B b, C c);
+
+ default TriFunction andThen(
+ Function super R, ? extends V> after) {
+ Objects.requireNonNull(after);
+ return (A a, B b, C c) -> after.apply(apply(a, b, c));
+ }
+}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java
index 9df086eb36bb..9046c554f0a2 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/formatter/ExpressionFormatter.java
@@ -40,6 +40,8 @@
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
@@ -104,6 +106,11 @@ public String visitStringLiteral(final StringLiteral node, final Context context
return formatStringLiteral(node.getValue());
}
+ @Override
+ public String visitLambdaLiteral(final LambdaLiteral node, final Context context) {
+ return String.valueOf(node.getValue());
+ }
+
@Override
public String visitSubscriptExpression(final SubscriptExpression node, final Context context) {
return process(node.getBase(), context)
@@ -387,6 +394,18 @@ public String visitInListExpression(final InListExpression node, final Context c
return "(" + joinExpressions(node.getValues(), context) + ")";
}
+ @Override
+ public String visitLambdaExpression(
+ final LambdaFunctionCall node, final Context context) {
+ final StringBuilder builder = new StringBuilder();
+
+ builder.append('(');
+ Joiner.on(", ").appendTo(builder, node.getArguments());
+ builder.append(") => ");
+ builder.append(process(node.getBody(), context));
+ return builder.toString();
+ }
+
private String formatBinaryExpression(
final String operator, final Expression left, final Expression right, final Context context
) {
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java
index 0783e4075154..1fe2da36a063 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/ExpressionVisitor.java
@@ -87,4 +87,7 @@ default R process(final Expression node, final C context) {
R visitWhenClause(WhenClause exp, C context);
+ R visitLambdaExpression(LambdaFunctionCall exp, C context);
+
+ R visitLambdaLiteral(LambdaLiteral exp, C context);
}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java
new file mode 100644
index 000000000000..7730a0bff15d
--- /dev/null
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaFunctionCall.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2021 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"; you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.execution.expression.tree;
+
+import static java.util.Objects.requireNonNull;
+
+import com.google.common.collect.ImmutableList;
+import io.confluent.ksql.parser.NodeLocation;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+
+public class LambdaFunctionCall extends Expression {
+
+ private final ImmutableList arguments;
+ private final Expression body;
+
+ public LambdaFunctionCall(
+ final List name,
+ final Expression body
+ ) {
+ this(Optional.empty(), name, body);
+ }
+
+ public LambdaFunctionCall(
+ final Optional location,
+ final List arguments,
+ final Expression body
+ ) {
+ super(location);
+ this.arguments = ImmutableList.copyOf(requireNonNull(arguments, "arguments"));
+ if (arguments.size() == 0) {
+ throw new IllegalArgumentException(
+ String.format("Lambda expression must have at least 1 argument. => %s", body.toString()));
+ }
+ final Set set = new HashSet<>(arguments);
+ if (set.size() < arguments.size()) {
+ throw new IllegalArgumentException(
+ String.format("Lambda arguments have duplicates: %s", arguments.toString()));
+ }
+ this.body = requireNonNull(body, "body is null");
+ }
+
+ public List getArguments() {
+ return arguments;
+ }
+
+ public Expression getBody() {
+ return body;
+ }
+
+ @Override
+ public R accept(final ExpressionVisitor visitor, final C context) {
+ return visitor.visitLambdaExpression(this, context);
+ }
+
+ @Override
+ public boolean equals(final Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+ final LambdaFunctionCall that = (LambdaFunctionCall) obj;
+ return Objects.equals(arguments, that.arguments)
+ && Objects.equals(body, that.body);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(arguments, body);
+ }
+}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java
new file mode 100644
index 000000000000..840cb6d77b7d
--- /dev/null
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/LambdaLiteral.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright 2021 Confluent Inc.
+ *
+ * Licensed under the Confluent Community License (the "License"); you may not use
+ * this file except in compliance with the License. You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.ksql.execution.expression.tree;
+
+import com.google.errorprone.annotations.Immutable;
+import io.confluent.ksql.parser.NodeLocation;
+
+import java.util.Objects;
+import java.util.Optional;
+
+@Immutable
+public class LambdaLiteral extends Literal {
+
+ private final String lambdaCharacter;
+
+ public LambdaLiteral(final String lambdaCharacter) {
+ this(Optional.empty(), lambdaCharacter);
+ }
+
+ public LambdaLiteral(final Optional location, final String lambdaCharacter) {
+ super(location);
+ this.lambdaCharacter = lambdaCharacter;
+ }
+
+ @Override
+ public String getValue() {
+ return lambdaCharacter;
+ }
+
+ @Override
+ public R accept(final ExpressionVisitor visitor, final C context) {
+ return visitor.visitLambdaLiteral(this, context);
+ }
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ final LambdaLiteral that = (LambdaLiteral) o;
+ return lambdaCharacter.equals(that.lambdaCharacter);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(lambdaCharacter);
+ }
+}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java
index a09ab474d984..258e00809887 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/TraversalExpressionVisitor.java
@@ -98,6 +98,12 @@ public Void visitFunctionCall(final FunctionCall node, final C context) {
return null;
}
+ @Override
+ public Void visitLambdaExpression(final LambdaFunctionCall node, final C context) {
+ process(node.getBody(), context);
+ return null;
+ }
+
@Override
public Void visitDereferenceExpression(final DereferenceExpression node, final C context) {
process(node.getBase(), context);
@@ -197,6 +203,11 @@ public Void visitBooleanLiteral(final BooleanLiteral node, final C context) {
return null;
}
+ @Override
+ public Void visitLambdaLiteral(final LambdaLiteral node, final C context) {
+ return null;
+ }
+
@Override
public Void visitUnqualifiedColumnReference(
final UnqualifiedColumnReferenceExp node,
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java
index e47291858d83..b20cc95310da 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/expression/tree/VisitParentExpressionVisitor.java
@@ -213,4 +213,14 @@ public R visitType(final Type node, final C context) {
public R visitCast(final Cast node, final C context) {
return visitExpression(node, context);
}
+
+ @Override
+ public R visitLambdaExpression(final LambdaFunctionCall node, final C context) {
+ return visitExpression(node, context);
+ }
+
+ @Override
+ public R visitLambdaLiteral(final LambdaLiteral node, final C context) {
+ return visitLiteral(node, context);
+ }
}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java
index 4ba02c7fef00..8256592e4b1d 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ExpressionTypeManager.java
@@ -37,6 +37,8 @@
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
+import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
+import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
@@ -136,6 +138,27 @@ public Void visitArithmeticUnary(
return null;
}
+ @Override
+ // CHECKSTYLE_RULES.OFF: TodoComment
+ public Void visitLambdaExpression(
+ final LambdaFunctionCall node, final ExpressionTypeContext context
+ ) {
+ process(node.getBody(), context);
+ // TODO: add proper type inference
+ context.setSqlType(SqlTypes.INTEGER);
+ return null;
+ }
+
+ @Override
+ // CHECKSTYLE_RULES.OFF: TodoComment
+ public Void visitLambdaLiteral(
+ final LambdaLiteral node, final ExpressionTypeContext expressionTypeContext
+ ) {
+ // TODO: add proper type inference
+ expressionTypeContext.setSqlType(SqlTypes.INTEGER);
+ return null;
+ }
+
@Override
public Void visitNotExpression(
final NotExpression node, final ExpressionTypeContext expressionTypeContext
diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java
index a41d0b988d55..d2d504d85eec 100644
--- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java
+++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/codegen/helpers/LambdaUtilTest.java
@@ -19,8 +19,16 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
+import com.google.common.collect.ImmutableList;
import io.confluent.ksql.execution.codegen.CodeGenTestUtil;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.BiFunction;
import java.util.function.Function;
+
+import io.confluent.ksql.util.KsqlException;
+import io.confluent.ksql.util.Pair;
import org.junit.Test;
@SuppressWarnings("unchecked")
@@ -34,11 +42,60 @@ public void shouldGenerateFunctionCode() {
// When:
final String javaCode = LambdaUtil
- .function(argName, argType, argName + ".longValue() + 1");
+ .function(argName, argType, argName + " + 1");
// Then:
final Object result = CodeGenTestUtil.cookAndEval(javaCode, Function.class);
assertThat(result, is(instanceOf(Function.class)));
assertThat(((Function