From 09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 Jun 2015 23:11:42 -0700 Subject: [PATCH] [SPARK-8371] [SQL] improve unit test for MaxOf and MinOf and fix bugs a follow up of https://github.com/apache/spark/pull/6813 Author: Wenchen Fan Closes #6825 from cloud-fan/cg and squashes the following commits: 43170cc [Wenchen Fan] fix bugs in code gen --- .../expressions/codegen/CodeGenerator.scala | 4 +- .../ArithmeticExpressionSuite.scala | 46 +++++++++++++------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bd5475d2066fc..47c5455435ec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -175,8 +175,10 @@ class CodeGenContext { * Generate code for compare expression in Java */ def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + // java boolean doesn't support > or < operator + case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" // use c1 - c2 may overflow - case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" + case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case other => s"$c1.compare($c2)" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 4bbbbe6c7f091..6c93698f8017b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} +import org.apache.spark.sql.types.Decimal class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -123,23 +123,39 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) + test("MaxOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MaxOf(small, large), convert(2)) + checkEvaluation(MaxOf(large, small), convert(2)) + checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2)) + } + } - checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + test("MaxOf for atomic type") { + checkEvaluation(MaxOf(true, false), true) + checkEvaluation(MaxOf("abc", "bcd"), "bcd") + checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 3.toByte)) } - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) + test("MinOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MinOf(small, large), convert(1)) + checkEvaluation(MinOf(large, small), convert(1)) + checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1)) + } + } - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + test("MinOf for atomic type") { + checkEvaluation(MinOf(true, false), false) + checkEvaluation(MinOf("abc", "bcd"), "abc") + checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 2.toByte)) } }