diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java
index 8809c14da7c..cecfa9d2309 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/udf/datetimeUDF/TimeAddSubFunction.java
@@ -5,13 +5,20 @@
package org.opensearch.sql.calcite.udf.datetimeUDF;
+import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE;
+import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME;
+import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.restoreFunctionProperties;
import static org.opensearch.sql.calcite.utils.datetime.DateTimeApplyUtils.transferInputToExprValue;
import static org.opensearch.sql.expression.datetime.DateTimeFunctions.exprAddTime;
import static org.opensearch.sql.expression.datetime.DateTimeFunctions.exprSubTime;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
+import org.opensearch.sql.calcite.type.ExprSqlType;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
+import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.data.model.ExprTimeValue;
import org.opensearch.sql.data.model.ExprValue;
@@ -44,4 +51,35 @@ public Object eval(Object... args) {
return result.valueForCalcite();
}
}
+
+ /**
+ * ADDTIME and SUBTIME has special return type maps: (DATE/TIMESTAMP, DATE/TIMESTAMP/TIME) ->
+ * TIMESTAMP (TIME, DATE/TIMESTAMP/TIME) -> TIME Therefore, we create a special return type
+ * inference for them.
+ */
+ public static SqlReturnTypeInference getReturnTypeForTimeAddSub() {
+ return opBinding -> {
+ RelDataType operandType0 = opBinding.getOperandType(0);
+ if (operandType0 instanceof ExprSqlType) {
+ OpenSearchTypeFactory.ExprUDT exprUDT = ((ExprSqlType) operandType0).getUdt();
+ if (exprUDT == EXPR_DATE || exprUDT == EXPR_TIMESTAMP) {
+ return UserDefinedFunctionUtils.nullableTimestampUDT;
+ } else if (exprUDT == EXPR_TIME) {
+ return UserDefinedFunctionUtils.nullableTimeUDT;
+ } else {
+ throw new IllegalArgumentException("Unsupported UDT type");
+ }
+ }
+ SqlTypeName typeName = operandType0.getSqlTypeName();
+ return switch (typeName) {
+ case DATE, TIMESTAMP ->
+ // Return TIMESTAMP
+ UserDefinedFunctionUtils.nullableTimestampUDT;
+ case TIME ->
+ // Return TIME
+ UserDefinedFunctionUtils.nullableTimeUDT;
+ default -> throw new IllegalArgumentException("Unsupported type: " + typeName);
+ };
+ };
+ }
}
diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java
new file mode 100644
index 00000000000..3f5a5dcc670
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/DivideFunction.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.calcite.udf.mathUDF;
+
+import org.opensearch.sql.calcite.udf.UserDefinedFunction;
+import org.opensearch.sql.calcite.utils.MathUtils;
+import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
+
+public class DivideFunction implements UserDefinedFunction {
+
+ @Override
+ public Object eval(Object... args) {
+ if (UserDefinedFunctionUtils.containsNull(args)) {
+ return null;
+ }
+
+ Number dividend = (Number) args[0];
+ Number divisor = (Number) args[1];
+
+ if (divisor.doubleValue() == 0) {
+ return null;
+ }
+
+ if (MathUtils.isIntegral(dividend) && MathUtils.isIntegral(divisor)) {
+ long result = dividend.longValue() / divisor.longValue();
+ return MathUtils.coerceToWidestIntegralType(dividend, divisor, result);
+ }
+ double result = dividend.doubleValue() / divisor.doubleValue();
+ return MathUtils.coerceToWidestFloatingType(dividend, divisor, result);
+ }
+}
diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java
index 3c1d2817108..eac19d88791 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/udf/mathUDF/ModFunction.java
@@ -7,6 +7,7 @@
import java.math.BigDecimal;
import org.opensearch.sql.calcite.udf.UserDefinedFunction;
+import org.opensearch.sql.calcite.utils.MathUtils;
/**
* Calculate the remainder of x divided by y
@@ -32,33 +33,22 @@ public Object eval(Object... args) {
arg0.getClass().getSimpleName(), arg1.getClass().getSimpleName()));
}
- // TODO: This precision check is arbitrary.
- if (Math.abs(num1.doubleValue()) < 0.0000001) {
+ if (num1.doubleValue() == 0) {
return null;
}
- if (isIntegral(num0) && isIntegral(num1)) {
+ if (MathUtils.isIntegral(num0) && MathUtils.isIntegral(num1)) {
long l0 = num0.longValue();
long l1 = num1.longValue();
// It returns negative values when l0 is negative
long result = l0 % l1;
// Return the wider type between l0 and l1
- if (num0 instanceof Long || num1 instanceof Long) {
- return result;
- }
- return (int) result;
+ return MathUtils.coerceToWidestIntegralType(num0, num1, result);
}
BigDecimal b0 = new BigDecimal(num0.toString());
BigDecimal b1 = new BigDecimal(num1.toString());
BigDecimal result = b0.remainder(b1);
- if (num0 instanceof Double || num1 instanceof Double) {
- return result.doubleValue();
- }
- return result.floatValue();
- }
-
- private boolean isIntegral(Number n) {
- return n instanceof Byte || n instanceof Short || n instanceof Integer || n instanceof Long;
+ return MathUtils.coerceToWidestFloatingType(num0, num1, result.doubleValue());
}
}
diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java
index fdd7738ca7d..8cbfd6d1723 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/utils/BuiltinFunctionUtils.java
@@ -6,7 +6,6 @@
package org.opensearch.sql.calcite.utils;
import static java.lang.Math.E;
-import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.*;
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.getLegacyTypeName;
import static org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils.*;
@@ -90,6 +89,7 @@
import org.opensearch.sql.calcite.udf.datetimeUDF.YearWeekFunction;
import org.opensearch.sql.calcite.udf.mathUDF.CRC32Function;
import org.opensearch.sql.calcite.udf.mathUDF.ConvFunction;
+import org.opensearch.sql.calcite.udf.mathUDF.DivideFunction;
import org.opensearch.sql.calcite.udf.mathUDF.EulerFunction;
import org.opensearch.sql.calcite.udf.mathUDF.ModFunction;
import org.opensearch.sql.calcite.udf.mathUDF.SqrtFunction;
@@ -138,7 +138,8 @@ static SqlOperator translate(String op) {
case "*":
return SqlStdOperatorTable.MULTIPLY;
case "/":
- return SqlStdOperatorTable.DIVIDE;
+ return TransferUserDefinedFunction(
+ DivideFunction.class, "/", ReturnTypes.QUOTIENT_NULLABLE);
// Built-in String Functions
case "ASCII":
return SqlStdOperatorTable.ASCII;
@@ -216,8 +217,7 @@ static SqlOperator translate(String op) {
// The MOD function in PPL supports floating-point parameters, e.g., MOD(5.5, 2) = 1.5,
// MOD(3.1, 2.1) = 1.1,
// whereas SqlStdOperatorTable.MOD supports only integer / long parameters.
- return TransferUserDefinedFunction(
- ModFunction.class, "MOD", getLeastRestrictiveReturnTypeAmongArgsAt(List.of(0, 1)));
+ return TransferUserDefinedFunction(ModFunction.class, "MOD", ReturnTypes.LEAST_RESTRICTIVE);
case "PI":
return SqlStdOperatorTable.PI;
case "POW", "POWER":
@@ -265,9 +265,7 @@ static SqlOperator translate(String op) {
DateAddSubFunction.class, "DATE_SUB", timestampInference);
case "ADDTIME", "SUBTIME":
return TransferUserDefinedFunction(
- TimeAddSubFunction.class,
- capitalOP,
- UserDefinedFunctionUtils.getReturnTypeForTimeAddSub());
+ TimeAddSubFunction.class, capitalOP, TimeAddSubFunction.getReturnTypeForTimeAddSub());
case "DAY_OF_WEEK", "DAYOFWEEK":
return TransferUserDefinedFunction(
DayOfWeekFunction.class, capitalOP, INTEGER_FORCE_NULLABLE);
diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java
new file mode 100644
index 00000000000..b211dd9223a
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/calcite/utils/MathUtils.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.calcite.utils;
+
+public class MathUtils {
+ public static boolean isIntegral(Number n) {
+ return n instanceof Byte || n instanceof Short || n instanceof Integer || n instanceof Long;
+ }
+
+ /**
+ * Converts a long value to the least restrictive integral type based on the types of two input
+ * numbers.
+ *
+ *
This is useful when performing operations like division or modulo and you want to preserve
+ * the most appropriate type (e.g., int vs long).
+ *
+ * @param a one operand involved in the operation
+ * @param b another operand involved in the operation
+ * @param value the result to convert to the least restrictive integral type
+ * @return the value converted to Byte, Short, Integer, or Long
+ */
+ public static Number coerceToWidestIntegralType(Number a, Number b, long value) {
+ if (a instanceof Long || b instanceof Long) {
+ return value;
+ } else if (a instanceof Integer || b instanceof Integer) {
+ return (int) value;
+ } else if (a instanceof Short || b instanceof Short) {
+ return (short) value;
+ } else {
+ return (byte) value;
+ }
+ }
+
+ /**
+ * Converts a double value to the least restrictive floating type based on the types of two input
+ *
+ * @param a one operand involved in the operation
+ * @param b another operand involved in the operation
+ * @param value the result to convert to the least restrictive floating type
+ * @return the value converted to Float or Double
+ */
+ public static Number coerceToWidestFloatingType(Number a, Number b, double value) {
+ if (a instanceof Double || b instanceof Double) {
+ return value;
+ } else {
+ return (float) value;
+ }
+ }
+}
diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java
index 3758d1c73fa..176de3474a1 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/utils/UserDefinedFunctionUtils.java
@@ -104,37 +104,6 @@ public static SqlOperator TransferUserDefinedFunction(
udfFunction);
}
- /**
- * Infer return argument type as the widest return type among arguments as specified positions.
- * E.g. (Integer, Long) -> Long; (Double, Float, SHORT) -> Double
- *
- * @param positions positions where the return type should be inferred from
- * @param nullable whether the returned value is nullable
- * @return The type inference
- */
- public static SqlReturnTypeInference getLeastRestrictiveReturnTypeAmongArgsAt(
- List