Skip to content

Commit

Permalink
fix equalNullSafe
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jul 2, 2015
1 parent 1b0c8e6 commit 04ef4b0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
} else if (l == null || r == null) {
false
} else {
l == r
if (left.dataType != BinaryType) l == r
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

import scala.collection.immutable.HashSet

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{IntegerType, BooleanType}
import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType}


class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -126,8 +123,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
val two = Literal(2)
val three = Literal(3)
val nl = Literal(null)
val s = Seq(one, two)
val nullS = Seq(one, two, null)
checkEvaluation(InSet(one, hS), true)
checkEvaluation(InSet(two, hS), true)
checkEvaluation(InSet(two, nS), true)
Expand All @@ -137,43 +132,47 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
}

private def binaryComparisonTest(
name: String,
op: (Expression, Expression) => Expression,
result: Seq[Boolean]): Unit = {
val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
val equalValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
test("BinaryComparison: " + name) {
for (i <- 0 until result.length) {
checkEvaluation(op(smallValues(i), largeValues(i)), result(0))
checkEvaluation(op(smallValues(i), equalValues(i)), result(1))
checkEvaluation(op(largeValues(i), smallValues(i)), result(2))
}
}
}

test("BinaryComparison") {
val row = create_row(1, 2, 3, null, 3, null)
val c1 = 'a.int.at(0)
val c2 = 'a.int.at(1)
val c3 = 'a.int.at(2)
val c4 = 'a.int.at(3)
val c5 = 'a.int.at(4)
val c6 = 'a.int.at(5)
binaryComparisonTest("<", LessThan, Seq(true, false, false))
binaryComparisonTest("<=", LessThanOrEqual, Seq(true, true, false))
binaryComparisonTest(">", GreaterThan, Seq(false, false, true))
binaryComparisonTest(">=", GreaterThanOrEqual, Seq(false, true, true))
binaryComparisonTest("===", EqualTo, Seq(false, true, false))
binaryComparisonTest("<=>", EqualNullSafe, Seq(false, true, false))

test("BinaryComparison: null test") {
val normalInt = Literal(1)
val nullInt = Literal.create(null, IntegerType)

def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
checkEvaluation(op(nullInt, normalInt), null)
checkEvaluation(op(nullInt, nullInt), null)
}

checkEvaluation(LessThan(c1, c4), null, row)
checkEvaluation(LessThan(c1, c2), true, row)
checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row)
checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row)
checkEvaluation(
LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)

checkEvaluation(c1 < c2, true, row)
checkEvaluation(c1 <= c2, true, row)
checkEvaluation(c1 > c2, false, row)
checkEvaluation(c1 >= c2, false, row)
checkEvaluation(c1 === c2, false, row)
checkEvaluation(c1 !== c2, true, row)
checkEvaluation(c4 <=> c1, false, row)
checkEvaluation(c1 <=> c4, false, row)
checkEvaluation(c4 <=> c6, true, row)
checkEvaluation(c3 <=> c5, true, row)
checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row)
checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row)

val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))
val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-02"))
checkEvaluation(Literal(d1) < Literal(d2), true)

val ts1 = new Timestamp(12)
val ts2 = new Timestamp(123)
checkEvaluation(Literal("ab") < Literal("abc"), true)
checkEvaluation(Literal(ts1) < Literal(ts2), true)
nullTest(LessThan)
nullTest(LessThanOrEqual)
nullTest(GreaterThan)
nullTest(GreaterThanOrEqual)
nullTest(EqualTo)

checkEvaluation(normalInt <=> nullInt, false)
checkEvaluation(nullInt <=> normalInt, false)
checkEvaluation(nullInt <=> nullInt, true)
}
}

0 comments on commit 04ef4b0

Please sign in to comment.