Skip to content

Commit

Permalink
feat: add lambda syntax to grammar (#6868)
Browse files Browse the repository at this point in the history
* feat: add lambda syntax to grammar

* change g4

* spacing
  • Loading branch information
stevenpyzhang committed Jan 29, 2021
1 parent 03ad537 commit dd3f365
Show file tree
Hide file tree
Showing 24 changed files with 811 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ private KsqlConstants() {

public static final String DOT = ".";
public static final String STRUCT_FIELD_REF = "->";
public static final String LAMBDA_FUNCTION = "=>";

public static final String KSQL_SERVICE_ID_METRICS_TAG = "ksql_service_id";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.analyzer.Analysis.AliasedDataSource;
import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
Expand All @@ -41,9 +44,12 @@
import io.confluent.ksql.util.AmbiguousColumnException;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.UnknownSourceException;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

/**
* Validate and clean ASTs generated from externally supplied statements
Expand All @@ -55,6 +61,7 @@
* <li>No unqualified column references are ambiguous</li>
* <li>All single column select items have an alias set
* that ensures they are unique across all sources</li>
* <li>Lambda arguments don't overlap with column references</li>
* </ol>
*/
public final class AstSanitizer {
Expand All @@ -71,15 +78,15 @@ public static Statement sanitize(final Statement node, final MetaStore metaStore
final ExpressionRewriterPlugin expressionRewriterPlugin =
new ExpressionRewriterPlugin(dataSourceExtractor);

final BiFunction<Expression, Void, Expression> expressionRewriter =
(e, v) -> ExpressionTreeRewriter.rewriteWith(expressionRewriterPlugin::process, e, v);
final BiFunction<Expression, SanitizerContext, Expression> expressionRewriter =
(e,v) -> ExpressionTreeRewriter.rewriteWith(expressionRewriterPlugin::process, e, v);

return (Statement) new StatementRewriter<>(expressionRewriter, rewriterPlugin::process)
.rewrite(node, null);
.rewrite(node, new SanitizerContext());
}

private static final class RewriterPlugin extends
AstVisitor<Optional<AstNode>, StatementRewriter.Context<Void>> {
AstVisitor<Optional<AstNode>, StatementRewriter.Context<SanitizerContext>> {

private final MetaStore metaStore;
private final DataSourceExtractor dataSourceExtractor;
Expand All @@ -102,7 +109,7 @@ private static final class RewriterPlugin extends
@Override
protected Optional<AstNode> visitInsertInto(
final InsertInto node,
final StatementRewriter.Context<Void> ctx
final StatementRewriter.Context<SanitizerContext> ctx
) {
final DataSource target = metaStore.getSource(node.getTarget());
if (target == null) {
Expand All @@ -129,7 +136,7 @@ protected Optional<AstNode> visitInsertInto(
@Override
protected Optional<AstNode> visitSingleColumn(
final SingleColumn singleColumn,
final StatementRewriter.Context<Void> ctx
final StatementRewriter.Context<SanitizerContext> ctx
) {
if (singleColumn.getAlias().isPresent()) {
return Optional.empty();
Expand Down Expand Up @@ -157,7 +164,8 @@ protected Optional<AstNode> visitSingleColumn(
}

private static final class ExpressionRewriterPlugin extends
VisitParentExpressionVisitor<Optional<Expression>, Context<Void>> {
VisitParentExpressionVisitor<Optional<Expression>,
ExpressionTreeRewriter.Context<SanitizerContext>> {

private final DataSourceExtractor dataSourceExtractor;

Expand All @@ -169,9 +177,13 @@ private static final class ExpressionRewriterPlugin extends
@Override
public Optional<Expression> visitUnqualifiedColumnReference(
final UnqualifiedColumnReferenceExp expression,
final Context<Void> ctx
final ExpressionTreeRewriter.Context<SanitizerContext> ctx
) {
final ColumnName columnName = expression.getColumnName();
if (ctx.getContext().getLambdaArgs().size() > 0
&& ctx.getContext().getLambdaArgs().contains(columnName.text())) {
return Optional.of(new LambdaLiteral(columnName.text()));
}

final List<SourceName> sourceNames = dataSourceExtractor.getSourcesFor(columnName);

Expand All @@ -192,5 +204,44 @@ public Optional<Expression> visitUnqualifiedColumnReference(
)
);
}

/**
 * Validates the argument names of a lambda before its body is rewritten.
 *
 * <p>Each argument name is compared against the column names of every source
 * returned by the data source extractor; a match is rejected. The argument
 * names are then registered in the sanitizer context and the call is handed
 * to {@code visitExpression}.
 *
 * <p>NOTE(review): nothing in this method removes the arguments from the
 * context once the lambda has been visited, so they appear to stay registered
 * for the rest of the statement - confirm whether that is intended.
 *
 * @param expression the lambda function call being visited
 * @param ctx the rewriter context carrying the {@code SanitizerContext}
 * @return whatever {@code visitExpression} returns for this node
 * @throws KsqlException if an argument name equals a column name of any source,
 *     or if registering the arguments in the context is rejected
 */
@Override
public Optional<Expression> visitLambdaExpression(
    final LambdaFunctionCall expression,
    final ExpressionTreeRewriter.Context<SanitizerContext> ctx
) {
  final List<String> lambdaArguments = expression.getArguments();

  dataSourceExtractor.getAllSources().forEach(source -> {
    // Gather this source's column names once, then test every argument against them.
    final Set<String> columnNames = source.getDataSource().getSchema().columns().stream()
        .map(column -> column.name().text())
        .collect(Collectors.toSet());

    for (final String argument : lambdaArguments) {
      if (columnNames.contains(argument)) {
        throw new KsqlException(
            String.format(
                "Lambda function argument can't be a column name: %s", argument));
      }
    }
  });

  ctx.getContext().addLambdaArg(lambdaArguments);
  return visitExpression(expression, ctx);
}
}

/**
 * Mutable state threaded through a single sanitize pass: the set of lambda
 * argument names registered so far.
 */
private static class SanitizerContext {
  private final Set<String> lambdaArgs = new HashSet<>();

  /**
   * Registers the argument names of a lambda.
   *
   * <p>Every name in {@code newArguments} must be new to this context and must
   * not be repeated within {@code newArguments} itself. The set must therefore
   * grow by exactly {@code newArguments.size()}; comparing against
   * {@code previous + 1} would let a multi-argument lambda reuse some of the
   * already-registered names as long as it introduced at least one new name.
   *
   * @param newArguments the argument names declared by the lambda
   * @throws KsqlException if any name is already registered or is repeated
   */
  private void addLambdaArg(final List<String> newArguments) {
    final int previousLambdaArgumentsLength = lambdaArgs.size();
    lambdaArgs.addAll(newArguments);
    if (lambdaArgs.size() < previousLambdaArgumentsLength + newArguments.size()) {
      throw new KsqlException(
          "Reusing lambda arguments in nested lambda is not allowed");
    }
  }

  /**
   * Returns an immutable snapshot of the argument names registered so far.
   *
   * @return a copy of the registered lambda argument names
   */
  private Set<String> getLambdaArgs() {
    return ImmutableSet.copyOf(lambdaArgs);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
Expand Down Expand Up @@ -72,7 +74,7 @@
* @param <C> A context type to be passed through to the plugin.
*/
public final class ExpressionTreeRewriter<C> {

public static final class Context<C> {
private final C context;
private final ExpressionVisitor<Expression, C> rewriter;
Expand Down Expand Up @@ -465,6 +467,18 @@ public Expression visitCast(final Cast node, final C context) {
return new Cast(node.getLocation(), expression, type);
}

/**
 * Rewrites a lambda function call.
 *
 * <p>The plugin is consulted first; if it supplies a replacement, that is
 * returned unchanged. Otherwise the lambda body is rewritten and a new
 * {@code LambdaFunctionCall} is built with the original location and
 * arguments.
 */
@Override
public Expression visitLambdaExpression(final LambdaFunctionCall node, final C context) {
  return plugin.apply(node, new Context<>(context, this))
      .orElseGet(() -> new LambdaFunctionCall(
          node.getLocation(),
          node.getArguments(),
          rewriter.apply(node.getBody(), context)));
}

@Override
public Expression visitBooleanLiteral(final BooleanLiteral node, final C context) {
return plugin.apply(node, new Context<>(context, this)).orElse(node);
Expand All @@ -485,6 +499,11 @@ public Expression visitLongLiteral(final LongLiteral node, final C context) {
return plugin.apply(node, new Context<>(context, this)).orElse(node);
}

/**
 * Rewrites a lambda literal: returns the plugin's replacement if it supplies
 * one, otherwise the node itself.
 */
@Override
public Expression visitLambdaLiteral(final LambdaLiteral node, final C context) {
  final Optional<Expression> replacement = plugin.apply(node, new Context<>(context, this));
  return replacement.orElse(node);
}

@Override
public Expression visitNullLiteral(final NullLiteral node, final C context) {
return plugin.apply(node, new Context<>(context, this)).orElse(node);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2021 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.engine.rewrite;

import io.confluent.ksql.util.KsqlException;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;

/**
 * Holds the list of lambda argument names collected so far and rejects
 * additions that would introduce a duplicate name.
 */
public class LambdaContext {
  private final List<String> lambdaArguments;

  /**
   * Creates a context seeded with a copy of the given argument names.
   *
   * @param lambdaArguments the initial argument names; must not be null
   */
  public LambdaContext(final List<String> lambdaArguments) {
    this.lambdaArguments = new ArrayList<>(
        Objects.requireNonNull(lambdaArguments, "lambdaArguments"));
  }

  /**
   * Appends further argument names to this context.
   *
   * <p>The check runs before the context is modified, so a rejected call
   * leaves the stored arguments exactly as they were.
   *
   * @param newArguments the argument names to append; must not be null
   * @throws KsqlException if the stored names together with
   *     {@code newArguments} contain any name more than once
   */
  public void addLambdaArguments(final List<String> newArguments) {
    Objects.requireNonNull(newArguments, "newArguments");

    // Validate on a scratch copy so that a failure does not corrupt this context.
    final List<String> combined =
        new ArrayList<>(lambdaArguments.size() + newArguments.size());
    combined.addAll(lambdaArguments);
    combined.addAll(newArguments);
    if (new HashSet<>(combined).size() < combined.size()) {
      throw new KsqlException("Duplicate lambda arguments are not allowed.");
    }

    lambdaArguments.addAll(newArguments);
  }

  /**
   * Returns the stored argument names.
   *
   * <p>This is the backing list itself, not a copy.
   *
   * @return the lambda argument names held by this context
   */
  public List<String> getLambdaArguments() {
    return lambdaArguments;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@
import static org.mockito.Mockito.mock;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
import io.confluent.ksql.execution.expression.tree.LambdaLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.parser.AstBuilder;
import io.confluent.ksql.parser.DefaultKsqlParser;
Expand All @@ -36,6 +42,7 @@
import io.confluent.ksql.parser.tree.Select;
import io.confluent.ksql.parser.tree.SingleColumn;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.schema.Operator;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.MetaStoreFixture;
import java.util.List;
Expand Down Expand Up @@ -196,6 +203,69 @@ public void shouldAddQualifierForJoinColumnReferenceFromRight() {
))));
}

@Test
public void shouldSanitizeLambdaArguments() {
  // Given: a query whose lambda argument X is not a column of TEST1
  final Statement stmt = givenQuery(
      "SELECT TRANSFORM_ARRAY(Col4, X => X + 5) FROM TEST1;");

  // When:
  final Query sanitized = (Query) AstSanitizer.sanitize(stmt, META_STORE);

  // Then: X inside the lambda body is a LambdaLiteral and Col4 is qualified
  final LambdaFunctionCall expectedLambda = new LambdaFunctionCall(
      ImmutableList.of("X"),
      new ArithmeticBinaryExpression(
          Operator.ADD,
          new LambdaLiteral("X"),
          new IntegerLiteral(5))
  );
  final FunctionCall expectedCall = new FunctionCall(
      FunctionName.of("TRANSFORM_ARRAY"),
      ImmutableList.of(column(TEST1_NAME, "COL4"), expectedLambda)
  );
  final Select expectedSelect = new Select(ImmutableList.of(
      new SingleColumn(expectedCall, Optional.of(ColumnName.of("KSQL_COL_0")))
  ));

  assertThat(sanitized.getSelect(), is(expectedSelect));
}

@Test
public void shouldThrowOnColumnNamesUsedForLambdaArguments() {
  // Given: a lambda whose argument has the same name as a column
  final Statement stmt = givenQuery(
      "SELECT TRANSFORM_ARRAY(Col4, Col0 => Col0 + 5) FROM TEST1;");

  // When:
  final Exception thrown = assertThrows(
      KsqlException.class,
      () -> AstSanitizer.sanitize(stmt, META_STORE)
  );

  // Then:
  final String message = thrown.getMessage();
  assertThat(message,
      containsString("Lambda function argument can't be a column name: COL0"));
}

@Test
public void shouldThrowOnDuplicateLambdaArguments() {
  // Given: a nested lambda that declares the same argument name as its parent
  final Statement stmt = givenQuery(
      "SELECT TRANSFORM_ARRAY(Col4, X => TRANSFORM_ARRAY(Col4, X => X)) FROM TEST1;");

  // When:
  final Exception thrown = assertThrows(
      KsqlException.class,
      () -> AstSanitizer.sanitize(stmt, META_STORE)
  );

  // Then:
  final String message = thrown.getMessage();
  assertThat(message,
      containsString("Reusing lambda arguments in nested lambda is not allowed"));
}

@Test
public void shouldThrowOnAmbiguousQualifierForJoinColumnReference() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ public void shouldRewriteArithmeticBinaryUsingPlugin() {
shouldRewriteUsingPlugin(parsed);
}

@Test
public void shouldRewriteLambdaFunctionUsingPlugin() {
  // Given: a function call whose second argument is a lambda
  final Expression lambdaCall =
      parseExpression("TRANSFORM_ARRAY(Array[1,2], X => X + Col0)");

  // When/Then:
  shouldRewriteUsingPlugin(lambdaCall);
}

@Test
public void shouldRewriteBetweenPredicate() {
// Given:
Expand Down

0 comments on commit dd3f365

Please sign in to comment.