Skip to content

Commit

Permalink
[SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw…
Browse files Browse the repository at this point in the history
… the same exception when divide by zero

### What changes were proposed in this pull request?
When dividing by zero, `DivideYMInterval` and `DivideDTInterval` output
```
java.lang.ArithmeticException
/ by zero
```
But, in ansi mode, `select 2 / 0` will output
```
org.apache.spark.SparkArithmeticException
divide by zero
```
The behavior looks not inconsistent.

### Why are the changes needed?
Make consistent behavior.

### Does this PR introduce _any_ user-facing change?
'Yes'.

### How was this patch tested?
New tests.

Closes apache#33889 from beliefer/SPARK-36632.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: beliefer <beliefer@163.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit de0161a)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
2 people authored and catalinii committed Mar 4, 2022
1 parent 4f97db3 commit 1127d86
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 24 deletions.
Expand Up @@ -598,6 +598,17 @@ trait IntervalDivide {
}
}
}

def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match {
case _: DecimalType =>
if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
}

def divideByZeroCheckCodegen(dataType: DataType, value: String): String = dataType match {
case _: DecimalType => s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();"
case _ => s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();"
}
}

// Divide an year-month interval by a numeric
Expand Down Expand Up @@ -629,6 +640,7 @@ case class DivideYMInterval(

override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num)
divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Int], num)
}

Expand All @@ -650,17 +662,24 @@ case class DivideYMInterval(
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
// Casting to `Int` is safe here.
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}

override def toString: String = s"($left / $right)"
Expand Down Expand Up @@ -696,6 +715,7 @@ case class DivideDTInterval(

override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num)
divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Long], num)
}

Expand All @@ -711,17 +731,24 @@ case class DivideDTInterval(
|""".stripMargin
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
defineCodeGen(ctx, ev, (m, n) =>
s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
| .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();
""".stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
defineCodeGen(ctx, ev, (m, n) =>
s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)")
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
|${divideByZeroCheckCodegen(right.dataType, n)}
|${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
}

override def toString: String = s"($left / $right)"
Expand Down
Expand Up @@ -412,8 +412,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Period.ofMonths(1), 0) -> "/ by zero",
(Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN",
(Period.ofMonths(1), 0) -> "divide by zero",
(Period.ofMonths(Int.MinValue), 0d) -> "divide by zero",
(Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
Expand Down Expand Up @@ -447,8 +447,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

Seq(
(Duration.ofDays(1), 0) -> "/ by zero",
(Duration.ofMillis(Int.MinValue), 0d) -> "input is infinite or NaN",
(Duration.ofDays(1), 0) -> "divide by zero",
(Duration.ofMillis(Int.MinValue), 0d) -> "divide by zero",
(Duration.ofSeconds(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
Expand Down
Expand Up @@ -209,8 +209,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down Expand Up @@ -242,8 +242,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down
Expand Up @@ -203,8 +203,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down Expand Up @@ -236,8 +236,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
java.lang.ArithmeticException
/ by zero
org.apache.spark.SparkArithmeticException
divide by zero


-- !query
Expand Down
Expand Up @@ -2737,7 +2737,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("/ by zero"))
assert(e.getMessage.contains("divide by zero"))

val e2 = intercept[SparkException] {
Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
assert(e2.getMessage.contains("divide by zero"))

val e3 = intercept[SparkException] {
Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
assert(e3.getMessage.contains("divide by zero"))
}

test("SPARK-34875: divide day-time interval by numeric") {
Expand Down Expand Up @@ -2772,7 +2784,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
assert(e.getMessage.contains("/ by zero"))
assert(e.getMessage.contains("divide by zero"))

val e2 = intercept[SparkException] {
Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e2.isInstanceOf[ArithmeticException])
assert(e2.getMessage.contains("divide by zero"))

val e3 = intercept[SparkException] {
Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e3.isInstanceOf[ArithmeticException])
assert(e3.getMessage.contains("divide by zero"))
}

test("SPARK-34896: return day-time interval from dates subtraction") {
Expand Down

0 comments on commit 1127d86

Please sign in to comment.