Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhyde committed Mar 21, 2023
1 parent 628d649 commit a39c6b8
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 138 deletions.
5 changes: 3 additions & 2 deletions core/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -5999,9 +5999,10 @@ SqlNode BuiltinFunctionCall() :
{
//~ FUNCTIONS WITH SPECIAL SYNTAX ---------------------------------------
(
( <CAST> { s = span(); f = SqlStdOperatorTable.CAST; }
| <SAFE_CAST> { s = span(); f = SqlStdOperatorTable.SAFE_CAST; }
( <CAST> { f = SqlStdOperatorTable.CAST; }
| <SAFE_CAST> { f = SqlStdOperatorTable.SAFE_CAST; }
)
{ s = span(); }
<LPAREN> AddExpression(args, ExprContext.ACCEPT_SUB_QUERY)
<AS>
(
Expand Down
12 changes: 3 additions & 9 deletions core/src/main/java/org/apache/calcite/rex/RexSimplify.java
Original file line number Diff line number Diff line change
Expand Up @@ -2194,6 +2194,7 @@ && sameTypeOrNarrowsNullability(e.getType(), intExpr.getType())) {
return rexBuilder.makeCast(e.getType(), intExpr);
}
}
final boolean safe = e.getKind() == SqlKind.SAFE_CAST;
switch (operand.getKind()) {
case LITERAL:
final RexLiteral literal = (RexLiteral) operand;
Expand Down Expand Up @@ -2223,22 +2224,15 @@ && sameTypeOrNarrowsNullability(e.getType(), intExpr.getType())) {
}
final List<RexNode> reducedValues = new ArrayList<>();
final RexNode simplifiedExpr =
rexBuilder.makeCast(e.getType(),
operand,
e.getKind() == SqlKind.SAFE_CAST,
e.getKind() == SqlKind.SAFE_CAST);
rexBuilder.makeCast(e.getType(), operand, safe, safe);
executor.reduce(rexBuilder, ImmutableList.of(simplifiedExpr), reducedValues);
return requireNonNull(
Iterables.getOnlyElement(reducedValues));
default:
if (operand == e.getOperands().get(0)) {
return e;
} else {
return rexBuilder.makeCast(
e.getType(),
operand,
e.getKind() == SqlKind.SAFE_CAST,
e.getKind() == SqlKind.SAFE_CAST);
return rexBuilder.makeCast(e.getType(), operand, safe, safe);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/org/apache/calcite/sql/SqlKind.java
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,8 @@ public enum SqlKind {
*/
CAST,

/**
* The {@code SAFE_CAST} function. */
/** The {@code SAFE_CAST} function, which is similar to {@link #CAST} but
* returns NULL rather than throwing an error if the conversion fails. */
SAFE_CAST,

/**
Expand Down
22 changes: 14 additions & 8 deletions core/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.calcite.sql.fun;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFamily;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
Expand Down Expand Up @@ -44,9 +45,10 @@
import com.google.common.collect.SetMultimap;

import java.text.Collator;
import java.util.Arrays;
import java.util.Objects;

import static com.google.common.base.Preconditions.checkArgument;

import static org.apache.calcite.util.Static.RESOURCE;

/**
Expand Down Expand Up @@ -96,20 +98,17 @@ public SqlCastFunction(SqlKind kind) {
InferTypes.FIRST_KNOWN,
null,
SqlFunctionCategory.SYSTEM);
assert Arrays.asList(SqlKind.CAST, SqlKind.SAFE_CAST).contains(kind);
checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
}

//~ Methods ----------------------------------------------------------------

@Override public RelDataType inferReturnType(
SqlOperatorBinding opBinding) {
assert opBinding.getOperandCount() == 2;
RelDataType ret = opBinding.getOperandType(1);
RelDataType firstType = opBinding.getOperandType(0);
ret =
opBinding.getTypeFactory().createTypeWithNullability(
ret,
firstType.isNullable() || this.kind == SqlKind.SAFE_CAST);
final RelDataType ret =
deriveType(opBinding.getTypeFactory(), opBinding.getOperandType(0),
opBinding.getOperandType(1), kind == SqlKind.SAFE_CAST);
if (opBinding instanceof SqlCallBinding) {
SqlCallBinding callBinding = (SqlCallBinding) opBinding;
SqlNode operand0 = callBinding.operand(0);
Expand All @@ -126,6 +125,13 @@ public SqlCastFunction(SqlKind kind) {
return ret;
}

/** Derives the type of "CAST(expression AS targetType)". */
public static RelDataType deriveType(RelDataTypeFactory typeFactory,
RelDataType expressionType, RelDataType targetType, boolean safe) {
return typeFactory.createTypeWithNullability(targetType,
expressionType.isNullable() || safe);
}

@Override public String getSignatureTemplate(final int operandsCount) {
assert operandsCount == 2;
return "{0}({1} AS {2})";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5454,7 +5454,7 @@ ImmutableList<RelNode> retrieveCursors() {
subQuery = requireNonNull(getSubQuery(expr, null));
rex = requireNonNull(subQuery.expr);
return StandardConvertletTable.castToValidatedType(expr, rex,
validator(), rexBuilder);
validator(), rexBuilder, false);

case SELECT:
case EXISTS:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import org.apache.calcite.sql.fun.SqlBetweenOperator;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator;
import org.apache.calcite.sql.fun.SqlExtractFunction;
import org.apache.calcite.sql.fun.SqlInternalOperators;
Expand All @@ -78,7 +79,6 @@
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorImpl;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

Expand All @@ -91,12 +91,13 @@
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;

import static org.apache.calcite.sql.type.NonNullableAccessors.getComponentTypeOrThrow;
import static org.apache.calcite.util.Util.first;

Expand Down Expand Up @@ -599,9 +600,13 @@ protected RexNode convertCast(
SqlRexContext cx,
final SqlCall call) {
RelDataTypeFactory typeFactory = cx.getTypeFactory();
assert Arrays.asList(SqlKind.CAST, SqlKind.SAFE_CAST).contains(call.getKind());
final SqlValidator validator = cx.getValidator();
final SqlKind kind = call.getKind();
checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
final boolean safe = kind == SqlKind.SAFE_CAST;
final SqlNode left = call.operand(0);
final SqlNode right = call.operand(1);
final RexBuilder rexBuilder = cx.getRexBuilder();
if (right instanceof SqlIntervalQualifier) {
final SqlIntervalQualifier intervalQualifier =
(SqlIntervalQualifier) right;
Expand All @@ -611,38 +616,34 @@ protected RexNode convertCast(
BigDecimal sourceValue =
(BigDecimal) sourceInterval.getValue();
RexLiteral castedInterval =
cx.getRexBuilder().makeIntervalLiteral(sourceValue,
rexBuilder.makeIntervalLiteral(sourceValue,
intervalQualifier);
return castToValidatedType(cx, call, castedInterval, call.getKind());
return castToValidatedType(call, castedInterval, validator, rexBuilder,
safe);
} else if (left instanceof SqlNumericLiteral) {
RexLiteral sourceInterval =
(RexLiteral) cx.convertExpression(left);
BigDecimal sourceValue =
(BigDecimal) sourceInterval.getValue();
requireNonNull(sourceInterval.getValueAs(BigDecimal.class),
"sourceValue");
final BigDecimal multiplier = intervalQualifier.getUnit().multiplier;
sourceValue = SqlFunctions.multiply(sourceValue, multiplier);
RexLiteral castedInterval =
cx.getRexBuilder().makeIntervalLiteral(
sourceValue,
rexBuilder.makeIntervalLiteral(
SqlFunctions.multiply(sourceValue, multiplier),
intervalQualifier);
return castToValidatedType(cx, call, castedInterval, call.getKind());
return castToValidatedType(call, castedInterval, validator, rexBuilder,
safe);
}
return castToValidatedType(cx, call, cx.convertExpression(left), call.getKind());
RexNode value = cx.convertExpression(left);
return castToValidatedType(call, value, validator, rexBuilder, safe);
}
SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
// If SAFE_CAST, allow nullable.

final RexNode arg = cx.convertExpression(left);
final SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
RelDataType type =
// dataType.deriveType(cx.getValidator(), call.getKind() == SqlKind.SAFE_CAST);
dataType.deriveType(cx.getValidator());
if (type == null) {
type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
}
RexNode arg = cx.convertExpression(left);
if (arg.getType().isNullable() || call.getKind() == SqlKind.SAFE_CAST) {
type = typeFactory.createTypeWithNullability(type, true);
}
SqlCastFunction.deriveType(cx.getTypeFactory(), arg.getType(),
dataType.deriveType(validator), safe);
if (SqlUtil.isNullLiteral(left, false)) {
final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
validator.setValidatedNodeType(left, type);
return cx.convertExpression(left);
}
Expand All @@ -651,7 +652,7 @@ protected RexNode convertCast(

// arg.getType() may be ANY
if (argComponentType == null) {
argComponentType = dataType.getComponentTypeSpec().deriveType(cx.getValidator());
argComponentType = dataType.getComponentTypeSpec().deriveType(validator);
}

requireNonNull(argComponentType, () -> "componentType of " + arg);
Expand All @@ -673,11 +674,7 @@ protected RexNode convertCast(
type = typeFactory.createTypeWithNullability(type, isn);
}
}
return cx.getRexBuilder().makeCast(
type,
arg,
call.getKind() == SqlKind.SAFE_CAST,
call.getKind() == SqlKind.SAFE_CAST);
return rexBuilder.makeCast(type, arg, safe, safe);
}

protected RexNode convertFloorCeil(SqlRexContext cx, SqlCall call) {
Expand Down Expand Up @@ -1341,47 +1338,20 @@ private static Pair<RexNode, RexNode> convertOverlapsOperand(SqlRexContext cx,
return Pair.of(r0, r1);
}

/**
* Casts a RexNode value to the validated type of a SqlCall. If the value
* was already of the validated type, then the value is returned without an
* additional cast.
*/
@Deprecated // to be removed before 2.0
public RexNode castToValidatedType(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlCall call,
RexNode value) {
return castToValidatedType(call, value, cx.getValidator(),
cx.getRexBuilder());
}

/**
* Casts a RexNode value to the validated type of a SqlCall. If the value
* was already of the validated type, then the value is returned without an
* additional cast.
*/
public RexNode castToValidatedType(
@UnknownInitialization StandardConvertletTable this,
SqlRexContext cx,
SqlCall call,
RexNode value,
SqlKind kind) {
return castToValidatedType(call, value, cx.getValidator(),
cx.getRexBuilder(), kind);
cx.getRexBuilder(), false);
}

/**
* Casts a RexNode value to the validated type of a SqlCall. If the value
* was already of the validated type, then the value is returned without an
* additional cast.
*/
@Deprecated // to be removed before 2.0
public static RexNode castToValidatedType(SqlNode node, RexNode e,
SqlValidator validator, RexBuilder rexBuilder) {
final RelDataType type = validator.getValidatedNodeType(node);
if (e.getType() == type) {
return e;
}
return rexBuilder.makeCast(type, e);
return castToValidatedType(node, e, validator, rexBuilder, false);
}

/**
Expand All @@ -1390,13 +1360,12 @@ public static RexNode castToValidatedType(SqlNode node, RexNode e,
* additional cast.
*/
public static RexNode castToValidatedType(SqlNode node, RexNode e,
SqlValidator validator, RexBuilder rexBuilder, SqlKind kind) {
SqlValidator validator, RexBuilder rexBuilder, boolean safe) {
final RelDataType type = validator.getValidatedNodeType(node);
if (e.getType() == type) {
return e;
}
return rexBuilder.makeCast(type, e,
kind == SqlKind.SAFE_CAST, kind == SqlKind.SAFE_CAST);
return rexBuilder.makeCast(type, e, safe, safe);
}

/** Convertlet that handles {@code COVAR_POP}, {@code COVAR_SAMP},
Expand Down
20 changes: 20 additions & 0 deletions core/src/test/resources/sql/winagg.iq
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,26 @@ from emp;

!ok

# STDDEV applied to nullable column
select empno,
stddev(comm) over (order by empno rows unbounded preceding) as stdev
from emp
where deptno = 30
order by 1;
+-------+-------------------------------------------------+
| EMPNO | STDEV |
+-------+-------------------------------------------------+
| 7499 | |
| 7521 | 141.421356237309510106570087373256683349609375 |
| 7654 | 585.9465277082316561063635163009166717529296875 |
| 7698 | 585.9465277082316561063635163009166717529296875 |
| 7844 | 602.7713773341707792496890760958194732666015625 |
| 7900 | 602.7713773341707792496890760958194732666015625 |
+-------+-------------------------------------------------+
(6 rows)

!ok

!use post
# [CALCITE-1540] Support multiple columns in PARTITION BY clause of window function
select gender,deptno,
Expand Down
2 changes: 1 addition & 1 deletion site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2705,7 +2705,7 @@ BigQuery's type system uses confusingly different names for types and functions:
| h s | string1 NOT RLIKE string2 | Whether *string1* does not match regex pattern *string2* (similar to `NOT LIKE`, but uses Java regex)
| b o | RPAD(string, length[, pattern ]) | Returns a string or bytes value that consists of *string* appended to *length* with *pattern*
| b o | RTRIM(string) | Returns *string* with all blanks removed from the end
| b | SAFE_CAST(value AS type) | Converts a value to a given type. Returns *null* instead of raising an error.
| b | SAFE_CAST(value AS type) | Converts *value* to *type*, returning NULL if conversion fails
| b m p | SHA1(string) | Calculates a SHA-1 hash value of *string* and returns it as a hex string
| b o | SINH(numeric) | Returns the hyperbolic sine of *numeric*
| b m o p | SOUNDEX(string) | Returns the phonetic representation of *string*; throws if *string* is encoded with multi-byte encoding such as UTF-8
Expand Down

0 comments on commit a39c6b8

Please sign in to comment.