Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: improve handling of NULLs #5019

Merged
merged 7 commits into from Apr 14, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -25,6 +25,7 @@
import io.confluent.ksql.execution.codegen.CodeGenRunner;
import io.confluent.ksql.execution.codegen.ExpressionMetadata;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.NoopProcessingLogContext;
Expand Down Expand Up @@ -588,8 +589,13 @@ protected Object visitExpression(final Expression expression, final Void context
fieldName,
valueSqlType,
value));
})
.orElse(null);
}

});
/**
 * Handles a SQL NULL literal by returning {@code null} directly.
 *
 * @param node the NULL literal being visited; not inspected.
 * @param context unused.
 * @return always {@code null}.
 */
@Override
public Object visitNullLiteral(final NullLiteral node, final Void context) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids any compiling of code etc to handle a NULL literal.

// Short-circuit: the value of a NULL literal is known up front, so nothing needs compiling.
return null;
}
}
}
Expand Up @@ -112,7 +112,10 @@ public ExpressionMetadata buildCodeGenFromParseTree(
final SqlType expressionType = expressionTypeManager
.getExpressionSqlType(expression);

ee.setExpressionType(SQL_TO_JAVA_TYPE_CONVERTER.toJavaType(expressionType));
if (expressionType != null) {
// expressionType can be null if expression is NULL.
ee.setExpressionType(SQL_TO_JAVA_TYPE_CONVERTER.toJavaType(expressionType));
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This avoids an NPE. If not called the return value is left at the default Object, which is fine for nulls.


ee.cook(javaCode);

Expand All @@ -123,8 +126,7 @@ public ExpressionMetadata buildCodeGenFromParseTree(
expression
);
} catch (KsqlException | CompileException e) {
throw new KsqlException("Code generation failed for " + type
+ ": " + e.getMessage()
throw new KsqlException("Invalid " + type + ": " + e.getMessage()
+ ". expression:" + expression + ", schema:" + schema, e);
} catch (final Exception e) {
throw new RuntimeException("Unexpected error generating code for " + type
Expand Down
Expand Up @@ -26,6 +26,7 @@
import com.google.common.collect.Multiset;
import io.confluent.ksql.execution.codegen.helpers.ArrayAccess;
import io.confluent.ksql.execution.codegen.helpers.ArrayBuilder;
import io.confluent.ksql.execution.codegen.helpers.MapBuilder;
import io.confluent.ksql.execution.codegen.helpers.SearchedCaseFunction;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
Expand Down Expand Up @@ -124,7 +125,8 @@ public class SqlToJavaVisitor {
RoundingMode.class.getCanonicalName(),
SchemaBuilder.class.getCanonicalName(),
Struct.class.getCanonicalName(),
ArrayBuilder.class.getCanonicalName()
ArrayBuilder.class.getCanonicalName(),
MapBuilder.class.getCanonicalName()
);

private static final Map<Operator, String> DECIMAL_OPERATOR_NAME = ImmutableMap
Expand Down Expand Up @@ -834,7 +836,9 @@ public Pair<String, SqlType> visitCreateMapExpression(
final CreateMapExpression exp,
final Void context
) {
final StringBuilder map = new StringBuilder("ImmutableMap.builder()");
final StringBuilder map = new StringBuilder("new MapBuilder(");
map.append(exp.getMap().size());
map.append((')'));
Comment on lines +818 to +820
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switch to a new builder type that won't throw on null keys or values.


for (Entry<Expression, Expression> entry: exp.getMap().entrySet()) {
map.append(".put(");
Expand Down Expand Up @@ -925,27 +929,25 @@ private CastVisitor() {
}

static Pair<String, SqlType> getCast(final Pair<String, SqlType> expr, final SqlType sqlType) {
if (!sqlType.supportsCast()) {
throw new KsqlFunctionException(
"Only casts to primitive types and decimal are supported: " + sqlType);
}

final SqlType rightSchema = expr.getRight();
if (sqlType.equals(rightSchema) || rightSchema == null) {
final SqlType sourceType = expr.getRight();
if (sourceType == null || sqlType.equals(sourceType)) {
// sourceType is null if source is SQL NULL
return new Pair<>(expr.getLeft(), sqlType);
}

return CASTERS.getOrDefault(
sqlType.baseType(),
(e, t, r) -> {
throw new KsqlException("Invalid cast operation: " + t);
}
)
.cast(expr, sqlType, sqlType);
return CASTERS.getOrDefault(sqlType.baseType(), CastVisitor::unsupportedCast)
.cast(expr, sqlType);
}
Comment on lines +907 to +915
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getCast no longer calls sqlType.supportsCast, as it is the getCast method itself that determines which casts are supported and which are not. Hence supportsCast is superfluous and likely to get out of date with what this method supports.


/**
 * Cast handler for target types that have no dedicated caster. It never returns normally.
 *
 * @param expr the source expression: generated code paired with its SQL type.
 * @param returnType the SQL type the expression was to be cast to.
 * @return never returns.
 * @throws KsqlFunctionException always, naming the source and target types.
 */
private static Pair<String, SqlType> unsupportedCast(
    final Pair<String, SqlType> expr, final SqlType returnType
) {
  final SqlType sourceType = expr.getRight();
  final String message = "Cast of " + sourceType + " to " + returnType + " is not supported";
  throw new KsqlFunctionException(message);
}

private static Pair<String, SqlType> castString(
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType
final Pair<String, SqlType> expr, final SqlType returnType
) {
final SqlType schema = expr.getRight();
final String exprStr;
Expand All @@ -961,13 +963,13 @@ private static Pair<String, SqlType> castString(
}

private static Pair<String, SqlType> castBoolean(
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType
final Pair<String, SqlType> expr, final SqlType returnType
) {
return new Pair<>(getCastToBooleanString(expr.getRight(), expr.getLeft()), returnType);
}

private static Pair<String, SqlType> castInteger(
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType
final Pair<String, SqlType> expr, final SqlType returnType
) {
final String exprStr = getCastString(
expr.getRight(),
Expand All @@ -979,7 +981,7 @@ private static Pair<String, SqlType> castInteger(
}

private static Pair<String, SqlType> castLong(
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType
final Pair<String, SqlType> expr, final SqlType returnType
) {
final String exprStr = getCastString(
expr.getRight(),
Expand All @@ -991,7 +993,7 @@ private static Pair<String, SqlType> castLong(
}

private static Pair<String, SqlType> castDouble(
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType
final Pair<String, SqlType> expr, final SqlType returnType
) {
final String exprStr = getCastString(
expr.getRight(),
Expand All @@ -1003,13 +1005,13 @@ private static Pair<String, SqlType> castDouble(
}

private static Pair<String, SqlType> castDecimal(
final Pair<String, SqlType> expr, final SqlType sqltype, final SqlType returnType
final Pair<String, SqlType> expr, final SqlType returnType
) {
if (!(sqltype instanceof SqlDecimal)) {
throw new KsqlException("Expected decimal type: " + sqltype);
if (!(returnType instanceof SqlDecimal)) {
throw new KsqlException("Expected decimal type: " + returnType);
Comment on lines +983 to +986
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: these methods were always called with the same param for returnType and sqlType: removed the duplication.

}

final SqlDecimal sqlDecimal = (SqlDecimal) sqltype;
final SqlDecimal sqlDecimal = (SqlDecimal) returnType;

if (expr.getRight().baseType() == SqlBaseType.DECIMAL && expr.right.equals(sqlDecimal)) {
return expr;
Expand Down Expand Up @@ -1049,7 +1051,6 @@ private static String getCastString(
return "(new Double(" + exprStr + ")." + javaTypeMethod + ")";
case STRING:
return javaStringParserMethod + "(" + exprStr + ")";

default:
throw new KsqlFunctionException(
"Invalid cast operation: Cannot cast "
Expand Down Expand Up @@ -1086,7 +1087,6 @@ private interface CastFunction {

Pair<String, SqlType> cast(
Pair<String, SqlType> expr,
SqlType sqltype,
SqlType returnType
);
}
Expand All @@ -1104,5 +1104,4 @@ private CaseWhenProcessed(
this.thenProcessResult = thenProcessResult;
}
}

}
@@ -0,0 +1,42 @@
/*
* Copyright 2020 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"; you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.execution.codegen.helpers;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Used to construct maps using the builder pattern. Note that we cannot use {@link
 * com.google.common.collect.ImmutableMap} because it does not accept null keys or values,
 * whereas this builder accepts both.
 *
 * <p>Instances are not thread-safe.
 */
public class MapBuilder {

  // Accumulates entries; backed by a HashMap so that null keys and null values are permitted.
  private final Map<Object, Object> map;

  /**
   * Creates an empty builder.
   *
   * @param size the initial capacity handed to the backing {@link HashMap}.
   * @throws IllegalArgumentException if {@code size} is negative.
   */
  public MapBuilder(final int size) {
    map = new HashMap<>(size);
  }

  /**
   * Adds an entry, replacing any previous value stored under the same key.
   *
   * @param key the entry's key; may be {@code null}.
   * @param value the entry's value; may be {@code null}.
   * @return this builder, to allow call chaining.
   */
  public MapBuilder put(final Object key, final Object value) {
    map.put(key, value);
    return this;
  }

  /**
   * Builds an unmodifiable map holding the entries added so far.
   *
   * <p>The result is a snapshot: it is copied from the builder's state, so later calls to
   * {@link #put(Object, Object)} on this builder do not alter maps that were already built.
   *
   * @return an unmodifiable map of the entries added so far.
   */
  public Map<Object, Object> build() {
    // Copy before wrapping: unmodifiableMap alone is only a read-only *view* of its argument.
    return Collections.unmodifiableMap(new HashMap<>(map));
  }
}
Expand Up @@ -55,7 +55,6 @@
import io.confluent.ksql.function.AggregateFunctionInitArguments;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.KsqlFunctionException;
import io.confluent.ksql.function.KsqlTableFunction;
import io.confluent.ksql.function.UdfFactory;
import io.confluent.ksql.schema.ksql.Column;
Expand Down Expand Up @@ -146,13 +145,7 @@ public Void visitNotExpression(

@Override
public Void visitCast(final Cast node, final ExpressionTypeContext expressionTypeContext) {
final SqlType sqlType = node.getType().getSqlType();
if (!sqlType.supportsCast()) {
throw new KsqlFunctionException("Only casts to primitive types or decimals "
+ "are supported: " + sqlType);
}

expressionTypeContext.setSqlType(sqlType);
expressionTypeContext.setSqlType(node.getType().getSqlType());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this as it's not needed. It's getCast in SqlToJavaVisitor that determines which casts are supported. Removing this check just means it fails, with the same error, slightly later in the flow.

return null;
}

Expand Down Expand Up @@ -404,7 +397,11 @@ public Void visitCreateMapExpression(
.collect(Collectors.toList());

if (keyTypes.stream().anyMatch(type -> !SqlTypes.STRING.equals(type))) {
throw new KsqlException("Only STRING keys are supported in maps but got: " + keyTypes);
final String types = keyTypes.stream()
.map(type -> type == null ? "NULL" : type.toString())
.collect(Collectors.joining(", ", "[", "]"));

throw new KsqlException("Only STRING keys are supported in maps but got: " + types);
}

final List<SqlType> valueTypes = exp.getMap()
Expand All @@ -414,9 +411,16 @@ public Void visitCreateMapExpression(
process(val, context);
return context.getSqlType();
})
.filter(Objects::nonNull)
.distinct()
.collect(Collectors.toList());

if (valueTypes.size() == 0) {
throw new KsqlException("Cannot construct a map with all NULL values "
+ "(see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may "
+ "cast a NULL value to the desired type.");
}

if (valueTypes.size() != 1) {
throw new KsqlException(
String.format(
Expand All @@ -425,11 +429,6 @@ public Void visitCreateMapExpression(
exp));
}

if (valueTypes.get(0) == null) {
throw new KsqlException("Cannot construct MAP with NULL values. As a workaround, you "
+ "may cast a NULL value to the desired type.");
}

context.setSqlType(SqlMap.of(valueTypes.get(0)));
return null;
}
Expand Down