diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java index 8809c14da7c..cecfa9d2309 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java @@ -5,13 +5,20 @@ package org.opensearch.sql.calcite.udf.datetimeUDF; +import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE; +import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME; +import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.restoreFunctionProperties; import static org.opensearch.sql.calcite.utils.datetime.DateTimeApplyUtils.transferInputToExprValue; import static org.opensearch.sql.expression.datetime.DateTimeFunctions.exprAddTime; import static org.opensearch.sql.expression.datetime.DateTimeFunctions.exprSubTime; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.type.SqlTypeName; +import org.opensearch.sql.calcite.type.ExprSqlType; import org.opensearch.sql.calcite.udf.UserDefinedFunction; +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; import org.opensearch.sql.data.model.ExprTimeValue; import org.opensearch.sql.data.model.ExprValue; @@ -44,4 +51,35 @@ public Object eval(Object... args) { return result.valueForCalcite(); } } + + /** + * ADDTIME and SUBTIME has special return type maps: (DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) -> + * TIMESTAMP (TIME, DATE/TIMESTAMP/TIME) -> TIME Therefore, we create a special return type + * inference for them. + */ + public static SqlReturnTypeInference getReturnTypeForTimeAddSub() { + return opBinding -> { + RelDataType operandType0 = opBinding.getOperandType(0); + if (operandType0 instanceof ExprSqlType) { + OpenSearchTypeFactory.ExprUDT exprUDT = ((ExprSqlType) operandType0).getUdt(); + if (exprUDT == EXPR_DATE || exprUDT == EXPR_TIMESTAMP) { + return UserDefinedFunctionUtils.nullableTimestampUDT; + } else if (exprUDT == EXPR_TIME) { + return UserDefinedFunctionUtils.nullableTimeUDT; + } else { + throw new IllegalArgumentException("Unsupported UDT type"); + } + } + SqlTypeName typeName = operandType0.getSqlTypeName(); + return switch (typeName) { + case DATE, TIMESTAMP -> + // Return TIMESTAMP + UserDefinedFunctionUtils.nullableTimestampUDT; + case TIME -> + // Return TIME + UserDefinedFunctionUtils.nullableTimeUDT; + default -> throw new IllegalArgumentException("Unsupported type: " + typeName); + }; + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java new file mode 100644 index 00000000000..3f5a5dcc670 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.udf.mathUDF; + +import org.opensearch.sql.calcite.udf.UserDefinedFunction; +import org.opensearch.sql.calcite.utils.MathUtils; +import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils; + +public class DivideFunction implements UserDefinedFunction { + + @Override + public Object eval(Object... args) { + if (UserDefinedFunctionUtils.containsNull(args)) { + return null; + } + + Number dividend = (Number) args[0]; + Number divisor = (Number) args[1]; + + if (divisor.doubleValue() == 0) { + return null; + } + + if (MathUtils.isIntegral(dividend) && MathUtils.isIntegral(divisor)) { + long result = dividend.longValue() / divisor.longValue(); + return MathUtils.coerceToWidestIntegralType(dividend, divisor, result); + } + double result = dividend.doubleValue() / divisor.doubleValue(); + return MathUtils.coerceToWidestFloatingType(dividend, divisor, result); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java index 3c1d2817108..eac19d88791 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java @@ -7,6 +7,7 @@ import java.math.BigDecimal; import org.opensearch.sql.calcite.udf.UserDefinedFunction; +import org.opensearch.sql.calcite.utils.MathUtils; /** * Calculate the remainder of x divided by y
@@ -32,33 +33,22 @@ public Object eval(Object... args) { arg0.getClass().getSimpleName(), arg1.getClass().getSimpleName())); } - // TODO: This precision check is arbitrary. - if (Math.abs(num1.doubleValue()) < 0.0000001) { + if (num1.doubleValue() == 0) { return null; } - if (isIntegral(num0) && isIntegral(num1)) { + if (MathUtils.isIntegral(num0) && MathUtils.isIntegral(num1)) { long l0 = num0.longValue(); long l1 = num1.longValue(); // It returns negative values when l0 is negative long result = l0 % l1; // Return the wider type between l0 and l1 - if (num0 instanceof Long || num1 instanceof Long) { - return result; - } - return (int) result; + return MathUtils.coerceToWidestIntegralType(num0, num1, result); } BigDecimal b0 = new BigDecimal(num0.toString()); BigDecimal b1 = new BigDecimal(num1.toString()); BigDecimal result = b0.remainder(b1); - if (num0 instanceof Double || num1 instanceof Double) { - return result.doubleValue(); - } - return result.floatValue(); - } - - private boolean isIntegral(Number n) { - return n instanceof Byte || n instanceof Short || n instanceof Integer || n instanceof Long; + return MathUtils.coerceToWidestFloatingType(num0, num1, result.doubleValue()); } } diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java index fdd7738ca7d..8cbfd6d1723 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java @@ -6,7 +6,6 @@ package org.opensearch.sql.calcite.utils; import static java.lang.Math.E; -import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.*; import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName; import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.*; @@ -90,6 +89,7 @@ import org.opensearch.sql.calcite.udf.datetimeUDF.YearWeekFunction; import org.opensearch.sql.calcite.udf.mathUDF.CRC32Function; import org.opensearch.sql.calcite.udf.mathUDF.ConvFunction; +import org.opensearch.sql.calcite.udf.mathUDF.DivideFunction; import org.opensearch.sql.calcite.udf.mathUDF.EulerFunction; import org.opensearch.sql.calcite.udf.mathUDF.ModFunction; import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction; @@ -138,7 +138,8 @@ static SqlOperator translate(String op) { case "*": return SqlStdOperatorTable.MULTIPLY; case "/": - return SqlStdOperatorTable.DIVIDE; + return TransferUserDefinedFunction( + DivideFunction.class, "/", ReturnTypes.QUOTIENT_NULLABLE); // Built-in String Functions case "ASCII": return SqlStdOperatorTable.ASCII; @@ -216,8 +217,7 @@ static SqlOperator translate(String op) { // The MOD function in PPL supports floating-point parameters, e.g., MOD(5.5, 2) = 1.5, // MOD(3.1, 2.1) = 1.1, // whereas SqlStdOperatorTable.MOD supports only integer / long parameters. - return TransferUserDefinedFunction( - ModFunction.class, "MOD", getLeastRestrictiveReturnTypeAmongArgsAt(List.of(0, 1))); + return TransferUserDefinedFunction(ModFunction.class, "MOD", ReturnTypes.LEAST_RESTRICTIVE); case "PI": return SqlStdOperatorTable.PI; case "POW", "POWER": @@ -265,9 +265,7 @@ static SqlOperator translate(String op) { DateAddSubFunction.class, "DATE_SUB", timestampInference); case "ADDTIME", "SUBTIME": return TransferUserDefinedFunction( - TimeAddSubFunction.class, - capitalOP, - UserDefinedFunctionUtils.getReturnTypeForTimeAddSub()); + TimeAddSubFunction.class, capitalOP, TimeAddSubFunction.getReturnTypeForTimeAddSub()); case "DAY_OF_WEEK", "DAYOFWEEK": return TransferUserDefinedFunction( DayOfWeekFunction.class, capitalOP, INTEGER_FORCE_NULLABLE); diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java new file mode 100644 index 00000000000..b211dd9223a --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java @@ -0,0 +1,52 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.utils; + +public class MathUtils { + public static boolean isIntegral(Number n) { + return n instanceof Byte || n instanceof Short || n instanceof Integer || n instanceof Long; + } + + /** + * Converts a long value to the least restrictive integral type based on the types of two input + * numbers. + * + *

This is useful when performing operations like division or modulo and you want to preserve + * the most appropriate type (e.g., int vs long). + * + * @param a one operand involved in the operation + * @param b another operand involved in the operation + * @param value the result to convert to the least restrictive integral type + * @return the value converted to Byte, Short, Integer, or Long + */ + public static Number coerceToWidestIntegralType(Number a, Number b, long value) { + if (a instanceof Long || b instanceof Long) { + return value; + } else if (a instanceof Integer || b instanceof Integer) { + return (int) value; + } else if (a instanceof Short || b instanceof Short) { + return (short) value; + } else { + return (byte) value; + } + } + + /** + * Converts a double value to the least restrictive floating type based on the types of two input + * + * @param a one operand involved in the operation + * @param b another operand involved in the operation + * @param value the result to convert to the least restrictive floating type + * @return the value converted to Float or Double + */ + public static Number coerceToWidestFloatingType(Number a, Number b, double value) { + if (a instanceof Double || b instanceof Double) { + return value; + } else { + return (float) value; + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java index 3758d1c73fa..176de3474a1 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java +++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java @@ -104,37 +104,6 @@ public static SqlOperator TransferUserDefinedFunction( udfFunction); } - /** - * Infer return argument type as the widest return type among arguments as specified positions. - * E.g. (Integer, Long) -> Long; (Double, Float, SHORT) -> Double - * - * @param positions positions where the return type should be inferred from - * @param nullable whether the returned value is nullable - * @return The type inference - */ - public static SqlReturnTypeInference getLeastRestrictiveReturnTypeAmongArgsAt( - List positions) { - return opBinding -> { - RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); - List types = new ArrayList<>(); - - for (int position : positions) { - if (position < 0 || position >= opBinding.getOperandCount()) { - throw new IllegalArgumentException("Invalid argument position: " + position); - } - types.add(opBinding.getOperandType(position)); - } - - RelDataType widerType = typeFactory.leastRestrictive(types); - if (widerType == null) { - throw new IllegalArgumentException( - "Cannot determine a common type for the given positions."); - } - - return widerType; - }; - } - static SqlReturnTypeInference getReturnTypeInferenceForArray() { return opBinding -> { RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); @@ -150,37 +119,6 @@ static SqlReturnTypeInference getReturnTypeInferenceForArray() { }; } - /** - * ADDTIME and SUBTIME has special return type maps: (DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) -> - * TIMESTAMP (TIME, DATE/TIMESTAMP/TIME) -> TIME Therefore, we create a special return type - * inference for them. - */ - static SqlReturnTypeInference getReturnTypeForTimeAddSub() { - return opBinding -> { - RelDataType operandType0 = opBinding.getOperandType(0); - if (operandType0 instanceof ExprSqlType) { - ExprUDT exprUDT = ((ExprSqlType) operandType0).getUdt(); - if (exprUDT == EXPR_DATE || exprUDT == EXPR_TIMESTAMP) { - return nullableTimestampUDT; - } else if (exprUDT == EXPR_TIME) { - return nullableTimeUDT; - } else { - throw new IllegalArgumentException("Unsupported UDT type"); - } - } - SqlTypeName typeName = operandType0.getSqlTypeName(); - return switch (typeName) { - case DATE, TIMESTAMP -> - // Return TIMESTAMP - nullableTimestampUDT; - case TIME -> - // Return TIME - nullableTimeUDT; - default -> throw new IllegalArgumentException("Unsupported type: " + typeName); - }; - }; - } - static List transferStringExprToDateValue(String timeExpr) { try { if (timeExpr.contains(":")) { diff --git a/docs/user/ppl/functions/expressions.rst b/docs/user/ppl/functions/expressions.rst index d25063d559c..3c4b2a5f1a2 100644 --- a/docs/user/ppl/functions/expressions.rst +++ b/docs/user/ppl/functions/expressions.rst @@ -28,7 +28,7 @@ Arithmetic expression is an expression formed by numeric literals and binary ari 1. ``+``: Add. 2. ``-``: Subtract. 3. ``*``: Multiply. -4. ``/``: Divide. For integers, the result is an integer with fractional part discarded. +4. ``/``: Divide. For integers, the result is an integer with fractional part discarded. Returns NULL when dividing by zero. 5. ``%``: Modulo. This can be used with integers only with remainder of the division as result. Precedence diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLBuiltinFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLBuiltinFunctionIT.java index c5a03757a83..6b1b33e09cc 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLBuiltinFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLBuiltinFunctionIT.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_DATATYPE_NUMERIC; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_DOG; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NULL_MISSING; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_STATE_COUNTRY; import static org.opensearch.sql.util.MatcherUtils.closeTo; import static org.opensearch.sql.util.MatcherUtils.rows; @@ -28,6 +29,7 @@ public void init() throws IOException { loadIndex(Index.STATE_COUNTRY_WITH_NULL); loadIndex(Index.DATA_TYPE_NUMERIC); loadIndex(Index.DOG); + loadIndex(Index.NULL_MISSING); } @Test @@ -291,8 +293,8 @@ public void testModShouldReturnWiderTypes() { executeQuery( String.format( "source=%s | eval b = byte_number %% 2, i = mod(integer_number, 3), l =" - + " mod(long_number, 2), f = float_number %% 2, d = mod(double_number, 2) |" - + " fields b, i, l, f, d", + + " mod(long_number, 2), f = float_number %% 2, d = mod(double_number, 2), s =" + + " short_number %% byte_number | fields b, i, l, f, d, s", TEST_INDEX_DATATYPE_NUMERIC)); verifySchema( actual, @@ -300,8 +302,9 @@ public void testModShouldReturnWiderTypes() { schema("i", "integer"), schema("l", "long"), schema("f", "float"), - schema("d", "double")); - verifyDataRows(actual, closeTo(0, 2, 1, 0.2, 1.1)); + schema("d", "double"), + schema("s", "short")); + verifyDataRows(actual, closeTo(0, 2, 1, 0.2, 1.1, 3)); } @Test @@ -377,4 +380,62 @@ public void testSignAndRound() { actual, schema("name", "string"), schema("age", "integer"), schema("thirty_one", "double")); verifyDataRows(actual, rows("Hello", 30, 31)); } + + @Test + public void testDivide() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | eval r1 = 22 / 7, r2 = integer_number / 1, r3 = 21 / 7, r4 =" + + " byte_number / short_number, r5 = half_float_number / float_number, r6 =" + + " float_number / short_number, r7 = 22 / 7.0, r8 = 22.0 / 7, r9 = 21.0 / 7.0," + + " r10 = half_float_number / short_number, r11 = double_number / float_number" + + " | fields r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11", + TEST_INDEX_DATATYPE_NUMERIC)); + verifySchema( + actual, + schema("r1", "integer"), + schema("r2", "integer"), + schema("r3", "integer"), + schema("r4", "short"), + schema("r5", "float"), + schema("r6", "float"), + schema("r7", "double"), + schema("r8", "double"), + schema("r9", "double"), + schema("r10", "float"), + schema("r11", "double")); + verifyDataRows( + actual, + closeTo( + 3, + 2, + 3, + 1, + 1.1774194, + 2.0666666, + 3.142857142857143, + 3.142857142857143, + 3.0, + 2.4333334, + 0.8225806704669051)); + } + + @Test + public void testDivideShouldReturnNull() { + JSONObject actual = + executeQuery( + String.format( + "source=%s | where key = 'null' | head 1 | eval r2 = 4 / dbl, r3 = `int` / 5, r4 =" + + " 22 / 0, r5 = 22.0 / 0, r6 = 22.0 / 0.0 | fields r2, r3, r4, r5, r6", + TEST_INDEX_NULL_MISSING)); + verifySchema( + actual, + schema("r2", "double"), + schema("r3", "integer"), + schema("r4", "integer"), + schema("r5", "double"), + schema("r6", "double")); + verifyDataRows(actual, rows(null, null, null, null, null)); + } }