Skip to content

Commit

Permalink
feat: implement correct logic for nested lambdas and more complex lam…
Browse files Browse the repository at this point in the history
…bda expressions (#7056)

* feat: implement correct logic for nested lambdas and more complex lambda expressions

* qtt

* update qtt and add more comments to clarify code
  • Loading branch information
stevenpyzhang committed Feb 26, 2021
1 parent 5c6e9cf commit 1a042cd
Show file tree
Hide file tree
Showing 13 changed files with 1,069 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,34 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte

@Override
public Void visitFunctionCall(final FunctionCall node, final TypeContext context) {
final List<SqlArgument> argumentTypes = new ArrayList<>();
final FunctionName functionName = node.getName();

// this context gets updated as we process non lambda arguments
final TypeContext currentTypeContext = context.getCopy();

final List<SqlArgument> argumentTypes = new ArrayList<>();
final List<TypeContext> typeContextsForChildren = new ArrayList<>();
final boolean hasLambda = node.hasLambdaFunctionCallArguments();
for (final Expression argExpr : node.getArguments()) {
final TypeContext childContext = context.getCopy();
final TypeContext childContext = TypeContextUtil.contextForExpression(
argExpr, context, currentTypeContext
);
typeContextsForChildren.add(childContext);

// pass a copy of the context to the type checker so that type checking in one
// expression subtree doesn't interfere with type checking in another one
final SqlType resolvedArgType =
expressionTypeManager.getExpressionSqlType(argExpr, childContext);
expressionTypeManager.getExpressionSqlType(argExpr, childContext.getCopy());

if (argExpr instanceof LambdaFunctionCall) {
argumentTypes.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
SqlLambda.of(currentTypeContext.getLambdaInputTypes(), resolvedArgType)));
} else {
argumentTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
context.visitType(resolvedArgType);
currentTypeContext.visitType(resolvedArgType);
}
}
}
Expand All @@ -211,8 +222,9 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context
function.newInstance(ksqlConfig)
);

for (final Expression argExpr : node.getArguments()) {
process(argExpr, context.getCopy());
final List<Expression> arguments = node.getArguments();
for (int i = 0; i < arguments.size(); i++) {
process(arguments.get(i), typeContextsForChildren.get(i));
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,21 +452,33 @@ public Pair<String, SqlType> visitFunctionCall(
final String instanceName = funNameToCodeName.apply(functionName);

final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

// this context gets updated as we process non lambda arguments
final TypeContext currentTypeContext = context.getCopy();

final List<SqlArgument> argumentSchemas = new ArrayList<>();
final List<TypeContext> typeContextsForChildren = new ArrayList<>();
final boolean hasLambda = node.hasLambdaFunctionCallArguments();

for (final Expression argExpr : node.getArguments()) {
final TypeContext childContext = context.getCopy();
final TypeContext childContext = TypeContextUtil.contextForExpression(
argExpr, context, currentTypeContext
);
typeContextsForChildren.add(childContext);

// pass a copy of the context to the type checker so that type checking in one
// expression subtree doesn't interfere with type checking in another one
final SqlType resolvedArgType =
expressionTypeManager.getExpressionSqlType(argExpr, childContext);
expressionTypeManager.getExpressionSqlType(argExpr, childContext.getCopy());
if (argExpr instanceof LambdaFunctionCall) {
argumentSchemas.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
SqlLambda.of(currentTypeContext.getLambdaInputTypes(), resolvedArgType)));
} else {
argumentSchemas.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
context.visitType(resolvedArgType);
currentTypeContext.visitType(resolvedArgType);
}
}
}
Expand Down Expand Up @@ -494,7 +506,7 @@ public Pair<String, SqlType> visitFunctionCall(
}

joiner.add(
process(convertArgument(arg, sqlType, paramType), context.getCopy())
process(convertArgument(arg, sqlType, paramType), typeContextsForChildren.get(i))
.getLeft());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2021 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.execution.codegen;

import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;

public final class TypeContextUtil {
private TypeContextUtil() {

}

/**
* Returns a copy of the appropriate context to use when processing an expression subtree.
* A copy is required to prevent different subtrees from getting a context that's been
* modified by another subtree. For non-lambdas we want to use the parent context because
* there may be valid overlapping lambda parameter names in different child nodes. For
* lambdas, we want to use the updateContext which has the type information the lambda
* expression body needs.
*
* @param expression the current expression we're processing
* @param parentContext the context passed into the parent node of the expression
* @param updatedContext the context that has been updated as we processed other child
* nodes
* @return a copy of either parent or current type context to be passed to the child
*/
public static TypeContext contextForExpression(
final Expression expression,
final TypeContext parentContext,
final TypeContext updatedContext
) {
if (expression instanceof LambdaFunctionCall) {
return updatedContext.getCopy();
} else {
return parentContext.getCopy();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.execution.codegen.TypeContext;
import io.confluent.ksql.execution.codegen.TypeContextUtil;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
Expand Down Expand Up @@ -473,23 +474,30 @@ public Void visitFunctionCall(

final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

// this context gets updated as we process non lambda arguments
final TypeContext currentTypeContext = expressionTypeContext.getCopy();

final List<SqlArgument> argTypes = new ArrayList<>();

final boolean hasLambda = node.hasLambdaFunctionCallArguments();
for (final Expression expression : node.getArguments()) {
final TypeContext childContext = expressionTypeContext.getCopy();
final TypeContext childContext = TypeContextUtil.contextForExpression(
expression, expressionTypeContext, currentTypeContext
);
process(expression, childContext);
final SqlType resolvedArgType = childContext.getSqlType();

if (expression instanceof LambdaFunctionCall) {
argTypes.add(
SqlArgument.of(
SqlLambda.of(expressionTypeContext.getLambdaInputTypes(),
childContext.getSqlType())));
SqlLambda.of(
currentTypeContext.getLambdaInputTypes(),
resolvedArgType)));
} else {
argTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
expressionTypeContext.visitType(resolvedArgType);
currentTypeContext.visitType(resolvedArgType);
}
}
}
Expand Down

0 comments on commit 1a042cd

Please sign in to comment.